clasificación MNIST

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar libreta

Este tutorial crea una red neuronal cuántica (QNN) para clasificar una versión simplificada de MNIST, similar al enfoque utilizado en Farhi et al . El rendimiento de la red neuronal cuántica en este problema de datos clásico se compara con una red neuronal clásica.

Configuración

pip install tensorflow==2.7.0

Instale TensorFlow Quantum:

pip install tensorflow-quantum
# Update package resources to account for version changes.
import importlib, pkg_resources
importlib.reload(pkg_resources)
<module 'pkg_resources' from '/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/pkg_resources/__init__.py'>

Ahora importa TensorFlow y las dependencias del módulo:

import tensorflow as tf
import tensorflow_quantum as tfq

import cirq
import sympy
import numpy as np
import seaborn as sns
import collections

# visualization tools
%matplotlib inline
import matplotlib.pyplot as plt
from cirq.contrib.svg import SVGCircuit
2022-02-04 12:29:39.759643: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

1. Cargue los datos

En este tutorial, construirá un clasificador binario para distinguir entre los dígitos 3 y 6, siguiendo a Farhi et al. Esta sección cubre el manejo de datos que:

  • Carga los datos sin procesar de Keras.
  • Filtra el conjunto de datos a solo 3 y 6 segundos.
  • Reduce la escala de las imágenes para que quepan en una computadora cuántica.
  • Elimina cualquier ejemplo contradictorio.
  • Convierte las imágenes binarias en circuitos Cirq.
  • Convierte los circuitos Cirq en circuitos TensorFlow Quantum.

1.1 Cargar los datos sin procesar

Cargue el conjunto de datos MNIST distribuido con Keras.

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Rescale the images from [0,255] to the [0.0,1.0] range.
x_train, x_test = x_train[..., np.newaxis]/255.0, x_test[..., np.newaxis]/255.0

print("Number of original training examples:", len(x_train))
print("Number of original test examples:", len(x_test))
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step
Number of original training examples: 60000
Number of original test examples: 10000

Filtre el conjunto de datos para mantener solo los 3 y los 6, elimine las otras clases. Al mismo tiempo, convierta la etiqueta, y , en booleana: True para 3 y False para 6.

def filter_36(x, y):
    keep = (y == 3) | (y == 6)
    x, y = x[keep], y[keep]
    y = y == 3
    return x,y
x_train, y_train = filter_36(x_train, y_train)
x_test, y_test = filter_36(x_test, y_test)

print("Number of filtered training examples:", len(x_train))
print("Number of filtered test examples:", len(x_test))
Number of filtered training examples: 12049
Number of filtered test examples: 1968

Muestra el primer ejemplo:

print(y_train[0])

plt.imshow(x_train[0, :, :, 0])
plt.colorbar()
True
<matplotlib.colorbar.Colorbar at 0x7fac6ad4bd90>

png

1.2 Reducir la escala de las imágenes

Un tamaño de imagen de 28x28 es demasiado grande para las computadoras cuánticas actuales. Cambie el tamaño de la imagen a 4x4:

x_train_small = tf.image.resize(x_train, (4,4)).numpy()
x_test_small = tf.image.resize(x_test, (4,4)).numpy()

Una vez más, muestre el primer ejemplo de entrenamiento, después de cambiar el tamaño:

print(y_train[0])

plt.imshow(x_train_small[0,:,:,0], vmin=0, vmax=1)
plt.colorbar()
True
<matplotlib.colorbar.Colorbar at 0x7fabf807fe10>

png

1.3 Eliminar ejemplos contradictorios

De la sección 3.3 Aprendiendo a distinguir dígitos de Farhi et al. , filtre el conjunto de datos para eliminar las imágenes etiquetadas como pertenecientes a ambas clases.

Este no es un procedimiento estándar de aprendizaje automático, pero se incluye con el fin de seguir el documento.

def remove_contradicting(xs, ys):
    mapping = collections.defaultdict(set)
    orig_x = {}
    # Determine the set of labels for each unique image:
    for x,y in zip(xs,ys):
       orig_x[tuple(x.flatten())] = x
       mapping[tuple(x.flatten())].add(y)

    new_x = []
    new_y = []
    for flatten_x in mapping:
      x = orig_x[flatten_x]
      labels = mapping[flatten_x]
      if len(labels) == 1:
          new_x.append(x)
          new_y.append(next(iter(labels)))
      else:
          # Throw out images that match more than one label.
          pass

    num_uniq_3 = sum(1 for value in mapping.values() if len(value) == 1 and True in value)
    num_uniq_6 = sum(1 for value in mapping.values() if len(value) == 1 and False in value)
    num_uniq_both = sum(1 for value in mapping.values() if len(value) == 2)

    print("Number of unique images:", len(mapping.values()))
    print("Number of unique 3s: ", num_uniq_3)
    print("Number of unique 6s: ", num_uniq_6)
    print("Number of unique contradicting labels (both 3 and 6): ", num_uniq_both)
    print()
    print("Initial number of images: ", len(xs))
    print("Remaining non-contradicting unique images: ", len(new_x))

    return np.array(new_x), np.array(new_y)

Los conteos resultantes no coinciden con los valores informados, pero no se especifica el procedimiento exacto.

También vale la pena señalar aquí que la aplicación de filtrado de ejemplos contradictorios en este punto no evita totalmente que el modelo reciba ejemplos de entrenamiento contradictorios: el siguiente paso binariza los datos que causarán más colisiones.

x_train_nocon, y_train_nocon = remove_contradicting(x_train_small, y_train)
Number of unique images: 10387
Number of unique 3s:  4912
Number of unique 6s:  5426
Number of unique contradicting labels (both 3 and 6):  49

Initial number of images:  12049
Remaining non-contradicting unique images:  10338

1.4 Codifique los datos como circuitos cuánticos

Para procesar imágenes usando una computadora cuántica, Farhi et al. propuso representar cada píxel con un qubit, con el estado dependiendo del valor del píxel. El primer paso es convertir a una codificación binaria.

THRESHOLD = 0.5

x_train_bin = np.array(x_train_nocon > THRESHOLD, dtype=np.float32)
x_test_bin = np.array(x_test_small > THRESHOLD, dtype=np.float32)

Si eliminara las imágenes contradictorias en este punto, solo le quedarían 193, lo que probablemente no sea suficiente para un entrenamiento efectivo.

_ = remove_contradicting(x_train_bin, y_train_nocon)
Number of unique images: 193
Number of unique 3s:  80
Number of unique 6s:  69
Number of unique contradicting labels (both 3 and 6):  44

Initial number of images:  10338
Remaining non-contradicting unique images:  149

Los qubits en los índices de píxeles con valores que exceden un umbral, se rotan a través de una puerta \(X\) .

def convert_to_circuit(image):
    """Encode truncated classical image into quantum datapoint."""
    values = np.ndarray.flatten(image)
    qubits = cirq.GridQubit.rect(4, 4)
    circuit = cirq.Circuit()
    for i, value in enumerate(values):
        if value:
            circuit.append(cirq.X(qubits[i]))
    return circuit


x_train_circ = [convert_to_circuit(x) for x in x_train_bin]
x_test_circ = [convert_to_circuit(x) for x in x_test_bin]

Aquí está el circuito creado para el primer ejemplo (los diagramas de circuito no muestran qubits con puertas cero):

SVGCircuit(x_train_circ[0])
findfont: Font family ['Arial'] not found. Falling back to DejaVu Sans.

SVG

Compare este circuito con los índices donde el valor de la imagen supera el umbral:

bin_img = x_train_bin[0,:,:,0]
indices = np.array(np.where(bin_img)).T
indices
array([[2, 2],
       [3, 1]])

Convierta estos circuitos Cirq en tensores para tfq :

x_train_tfcirc = tfq.convert_to_tensor(x_train_circ)
x_test_tfcirc = tfq.convert_to_tensor(x_test_circ)

2. Red neuronal cuántica

Hay poca orientación para una estructura de circuito cuántico que clasifique imágenes. Dado que la clasificación se basa en la expectativa del qubit de lectura, Farhi et al. Proponga el uso de dos puertas qubit, con el qubit de lectura siempre actuando. Esto es similar en algunos aspectos a ejecutar un RNN unitario pequeño en los píxeles.

2.1 Construya el circuito modelo

El siguiente ejemplo muestra este enfoque en capas. Cada capa usa n instancias de la misma puerta, con cada uno de los qubits de datos actuando sobre el qubit de lectura.

Comience con una clase simple que agregará una capa de estas puertas a un circuito:

class CircuitLayerBuilder():
    def __init__(self, data_qubits, readout):
        self.data_qubits = data_qubits
        self.readout = readout

    def add_layer(self, circuit, gate, prefix):
        for i, qubit in enumerate(self.data_qubits):
            symbol = sympy.Symbol(prefix + '-' + str(i))
            circuit.append(gate(qubit, self.readout)**symbol)

Cree una capa de circuito de ejemplo para ver cómo se ve:

demo_builder = CircuitLayerBuilder(data_qubits = cirq.GridQubit.rect(4,1),
                                   readout=cirq.GridQubit(-1,-1))

circuit = cirq.Circuit()
demo_builder.add_layer(circuit, gate = cirq.XX, prefix='xx')
SVGCircuit(circuit)

SVG

Ahora construya un modelo de dos capas, que coincida con el tamaño del circuito de datos, e incluya las operaciones de preparación y lectura.

def create_quantum_model():
    """Create a QNN model circuit and readout operation to go along with it."""
    data_qubits = cirq.GridQubit.rect(4, 4)  # a 4x4 grid.
    readout = cirq.GridQubit(-1, -1)         # a single qubit at [-1,-1]
    circuit = cirq.Circuit()

    # Prepare the readout qubit.
    circuit.append(cirq.X(readout))
    circuit.append(cirq.H(readout))

    builder = CircuitLayerBuilder(
        data_qubits = data_qubits,
        readout=readout)

    # Then add layers (experiment by adding more).
    builder.add_layer(circuit, cirq.XX, "xx1")
    builder.add_layer(circuit, cirq.ZZ, "zz1")

    # Finally, prepare the readout qubit.
    circuit.append(cirq.H(readout))

    return circuit, cirq.Z(readout)
model_circuit, model_readout = create_quantum_model()

2.2 Envuelva el circuito modelo en un modelo tfq-keras

Construya el modelo de Keras con los componentes cuánticos. Este modelo se alimenta de los "datos cuánticos", de x_train_circ , que codifica los datos clásicos. Utiliza una capa de circuito cuántico parametrizado, tfq.layers.PQC , para entrenar el circuito modelo en los datos cuánticos.

Para clasificar estas imágenes, Farhi et al. propuso tomar la expectativa de un qubit de lectura en un circuito parametrizado. La expectativa devuelve un valor entre 1 y -1.

# Build the Keras model.
model = tf.keras.Sequential([
    # The input is the data-circuit, encoded as a tf.string
    tf.keras.layers.Input(shape=(), dtype=tf.string),
    # The PQC layer returns the expected value of the readout gate, range [-1,1].
    tfq.layers.PQC(model_circuit, model_readout),
])

A continuación, describa el procedimiento de entrenamiento del modelo, utilizando el método de compile .

Dado que la lectura esperada está en el rango [-1,1] , optimizar la pérdida de la bisagra es un ajuste algo natural.

Para usar la pérdida de bisagra aquí, debe hacer dos pequeños ajustes. Primero convierta las etiquetas, y_train_nocon , de booleano a [-1,1] , como se esperaba por la pérdida de bisagra.

y_train_hinge = 2.0*y_train_nocon-1.0
y_test_hinge = 2.0*y_test-1.0

En segundo lugar, use una métrica de precisión de bisagra personalizada que maneje correctamente [-1, 1] como el y_true hinge_accuracy tf.losses.BinaryAccuracy(threshold=0.0) espera que y_true sea ​​un valor booleano, por lo que no se puede usar con pérdida de bisagra).

def hinge_accuracy(y_true, y_pred):
    y_true = tf.squeeze(y_true) > 0.0
    y_pred = tf.squeeze(y_pred) > 0.0
    result = tf.cast(y_true == y_pred, tf.float32)

    return tf.reduce_mean(result)
model.compile(
    loss=tf.keras.losses.Hinge(),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[hinge_accuracy])
print(model.summary())
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 pqc (PQC)                   (None, 1)                 32        
                                                                 
=================================================================
Total params: 32
Trainable params: 32
Non-trainable params: 0
_________________________________________________________________
None

Entrenar el modelo cuántico

Ahora entrene al modelo; esto toma alrededor de 45 min. Si no quiere esperar tanto, use un pequeño subconjunto de datos (establezca NUM_EXAMPLES=500 , a continuación). Esto realmente no afecta el progreso del modelo durante el entrenamiento (solo tiene 32 parámetros y no necesita muchos datos para restringirlos). Usar menos ejemplos solo termina el entrenamiento antes (5 minutos), pero se ejecuta lo suficiente como para mostrar que está progresando en los registros de validación.

EPOCHS = 3
BATCH_SIZE = 32

NUM_EXAMPLES = len(x_train_tfcirc)
x_train_tfcirc_sub = x_train_tfcirc[:NUM_EXAMPLES]
y_train_hinge_sub = y_train_hinge[:NUM_EXAMPLES]

El entrenamiento de este modelo para la convergencia debería lograr una precisión >85 % en el conjunto de prueba.

qnn_history = model.fit(
      x_train_tfcirc_sub, y_train_hinge_sub,
      batch_size=32,
      epochs=EPOCHS,
      verbose=1,
      validation_data=(x_test_tfcirc, y_test_hinge))

qnn_results = model.evaluate(x_test_tfcirc, y_test)
Epoch 1/3
324/324 [==============================] - 68s 207ms/step - loss: 0.6745 - hinge_accuracy: 0.7719 - val_loss: 0.3959 - val_hinge_accuracy: 0.8004
Epoch 2/3
324/324 [==============================] - 68s 209ms/step - loss: 0.3964 - hinge_accuracy: 0.8291 - val_loss: 0.3498 - val_hinge_accuracy: 0.8997
Epoch 3/3
324/324 [==============================] - 66s 204ms/step - loss: 0.3599 - hinge_accuracy: 0.8854 - val_loss: 0.3395 - val_hinge_accuracy: 0.9042
62/62 [==============================] - 3s 41ms/step - loss: 0.3395 - hinge_accuracy: 0.9042

3. Red neuronal clásica

Si bien la red neuronal cuántica funciona para este problema MNIST simplificado, una red neuronal clásica básica puede superar fácilmente a una QNN en esta tarea. Después de una sola época, una red neuronal clásica puede lograr una precisión >98 % en el conjunto reservado.

En el siguiente ejemplo, se usa una red neuronal clásica para el problema de clasificación 3-6 usando la imagen completa de 28x28 en lugar de submuestrear la imagen. Esto converge fácilmente a casi el 100 % de precisión del conjunto de prueba.

def create_classical_model():
    # A simple model based off LeNet from https://keras.io/examples/mnist_cnn/
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Conv2D(32, [3, 3], activation='relu', input_shape=(28,28,1)))
    model.add(tf.keras.layers.Conv2D(64, [3, 3], activation='relu'))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(tf.keras.layers.Dropout(0.25))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(128, activation='relu'))
    model.add(tf.keras.layers.Dropout(0.5))
    model.add(tf.keras.layers.Dense(1))
    return model


model = create_classical_model()
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(),
              metrics=['accuracy'])

model.summary()
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 26, 26, 32)        320       
                                                                 
 conv2d_1 (Conv2D)           (None, 24, 24, 64)        18496     
                                                                 
 max_pooling2d (MaxPooling2D  (None, 12, 12, 64)       0         
 )                                                               
                                                                 
 dropout (Dropout)           (None, 12, 12, 64)        0         
                                                                 
 flatten (Flatten)           (None, 9216)              0         
                                                                 
 dense (Dense)               (None, 128)               1179776   
                                                                 
 dropout_1 (Dropout)         (None, 128)               0         
                                                                 
 dense_1 (Dense)             (None, 1)                 129       
                                                                 
=================================================================
Total params: 1,198,721
Trainable params: 1,198,721
Non-trainable params: 0
_________________________________________________________________
model.fit(x_train,
          y_train,
          batch_size=128,
          epochs=1,
          verbose=1,
          validation_data=(x_test, y_test))

cnn_results = model.evaluate(x_test, y_test)
95/95 [==============================] - 3s 31ms/step - loss: 0.0400 - accuracy: 0.9842 - val_loss: 0.0057 - val_accuracy: 0.9970
62/62 [==============================] - 0s 3ms/step - loss: 0.0057 - accuracy: 0.9970

El modelo anterior tiene casi 1,2 millones de parámetros. Para una comparación más justa, pruebe un modelo de 37 parámetros, en las imágenes submuestreadas:

def create_fair_classical_model():
    # A simple model based off LeNet from https://keras.io/examples/mnist_cnn/
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Flatten(input_shape=(4,4,1)))
    model.add(tf.keras.layers.Dense(2, activation='relu'))
    model.add(tf.keras.layers.Dense(1))
    return model


model = create_fair_classical_model()
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(),
              metrics=['accuracy'])

model.summary()
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 flatten_1 (Flatten)         (None, 16)                0         
                                                                 
 dense_2 (Dense)             (None, 2)                 34        
                                                                 
 dense_3 (Dense)             (None, 1)                 3         
                                                                 
=================================================================
Total params: 37
Trainable params: 37
Non-trainable params: 0
_________________________________________________________________
model.fit(x_train_bin,
          y_train_nocon,
          batch_size=128,
          epochs=20,
          verbose=2,
          validation_data=(x_test_bin, y_test))

fair_nn_results = model.evaluate(x_test_bin, y_test)
Epoch 1/20
81/81 - 1s - loss: 0.6678 - accuracy: 0.6546 - val_loss: 0.6326 - val_accuracy: 0.7358 - 503ms/epoch - 6ms/step
Epoch 2/20
81/81 - 0s - loss: 0.6186 - accuracy: 0.7654 - val_loss: 0.5787 - val_accuracy: 0.7515 - 98ms/epoch - 1ms/step
Epoch 3/20
81/81 - 0s - loss: 0.5629 - accuracy: 0.7861 - val_loss: 0.5247 - val_accuracy: 0.7764 - 104ms/epoch - 1ms/step
Epoch 4/20
81/81 - 0s - loss: 0.5150 - accuracy: 0.8301 - val_loss: 0.4825 - val_accuracy: 0.8196 - 103ms/epoch - 1ms/step
Epoch 5/20
81/81 - 0s - loss: 0.4762 - accuracy: 0.8493 - val_loss: 0.4490 - val_accuracy: 0.8293 - 97ms/epoch - 1ms/step
Epoch 6/20
81/81 - 0s - loss: 0.4438 - accuracy: 0.8527 - val_loss: 0.4216 - val_accuracy: 0.8298 - 99ms/epoch - 1ms/step
Epoch 7/20
81/81 - 0s - loss: 0.4169 - accuracy: 0.8555 - val_loss: 0.3986 - val_accuracy: 0.8313 - 98ms/epoch - 1ms/step
Epoch 8/20
81/81 - 0s - loss: 0.3951 - accuracy: 0.8595 - val_loss: 0.3794 - val_accuracy: 0.8313 - 105ms/epoch - 1ms/step
Epoch 9/20
81/81 - 0s - loss: 0.3773 - accuracy: 0.8596 - val_loss: 0.3635 - val_accuracy: 0.8328 - 98ms/epoch - 1ms/step
Epoch 10/20
81/81 - 0s - loss: 0.3620 - accuracy: 0.8611 - val_loss: 0.3499 - val_accuracy: 0.8333 - 97ms/epoch - 1ms/step
Epoch 11/20
81/81 - 0s - loss: 0.3488 - accuracy: 0.8714 - val_loss: 0.3382 - val_accuracy: 0.8720 - 98ms/epoch - 1ms/step
Epoch 12/20
81/81 - 0s - loss: 0.3372 - accuracy: 0.8831 - val_loss: 0.3279 - val_accuracy: 0.8720 - 95ms/epoch - 1ms/step
Epoch 13/20
81/81 - 0s - loss: 0.3271 - accuracy: 0.8831 - val_loss: 0.3187 - val_accuracy: 0.8725 - 97ms/epoch - 1ms/step
Epoch 14/20
81/81 - 0s - loss: 0.3181 - accuracy: 0.8832 - val_loss: 0.3107 - val_accuracy: 0.8725 - 96ms/epoch - 1ms/step
Epoch 15/20
81/81 - 0s - loss: 0.3101 - accuracy: 0.8833 - val_loss: 0.3035 - val_accuracy: 0.8725 - 96ms/epoch - 1ms/step
Epoch 16/20
81/81 - 0s - loss: 0.3030 - accuracy: 0.8833 - val_loss: 0.2972 - val_accuracy: 0.8725 - 105ms/epoch - 1ms/step
Epoch 17/20
81/81 - 0s - loss: 0.2966 - accuracy: 0.8833 - val_loss: 0.2913 - val_accuracy: 0.8725 - 104ms/epoch - 1ms/step
Epoch 18/20
81/81 - 0s - loss: 0.2908 - accuracy: 0.8928 - val_loss: 0.2861 - val_accuracy: 0.8725 - 104ms/epoch - 1ms/step
Epoch 19/20
81/81 - 0s - loss: 0.2856 - accuracy: 0.8955 - val_loss: 0.2816 - val_accuracy: 0.8725 - 99ms/epoch - 1ms/step
Epoch 20/20
81/81 - 0s - loss: 0.2809 - accuracy: 0.8952 - val_loss: 0.2773 - val_accuracy: 0.8725 - 101ms/epoch - 1ms/step
62/62 [==============================] - 0s 895us/step - loss: 0.2773 - accuracy: 0.8725

4. Comparación

La entrada de mayor resolución y un modelo más potente facilitan este problema para la CNN. Mientras que un modelo clásico de potencia similar (~32 parámetros) entrena con una precisión similar en una fracción del tiempo. De una forma u otra, la red neuronal clásica supera fácilmente a la red neuronal cuántica. Para los datos clásicos, es difícil vencer a una red neuronal clásica.

qnn_accuracy = qnn_results[1]
cnn_accuracy = cnn_results[1]
fair_nn_accuracy = fair_nn_results[1]

sns.barplot(["Quantum", "Classical, full", "Classical, fair"],
            [qnn_accuracy, cnn_accuracy, fair_nn_accuracy])
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/seaborn/_decorators.py:43: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
  FutureWarning
<AxesSubplot:>

png