Use um modelo pré-treinado

Neste tutorial, você explorará um aplicativo da Web de exemplo que demonstra o aprendizado de transferência usando a API de camadas do TensorFlow.js. O exemplo carrega um modelo pré-treinado e treina novamente o modelo no navegador.

O modelo foi pré-treinado em Python nos dígitos 0-4 do conjunto de dados de classificação de dígitos MNIST . O retreinamento (ou aprendizado de transferência) no navegador usa os dígitos de 5 a 9. O exemplo mostra que as primeiras camadas de um modelo pré-treinado podem ser usadas para extrair recursos de novos dados durante o aprendizado de transferência, permitindo assim um treinamento mais rápido nos novos dados.

O aplicativo de exemplo para este tutorial está disponível online , então você não precisa baixar nenhum código ou configurar um ambiente de desenvolvimento. Se quiser executar o código localmente, conclua as etapas opcionais em Executar o exemplo localmente . Se você não deseja configurar um ambiente de desenvolvimento, pode pular para Explorar o exemplo .

O código de exemplo está disponível no GitHub .

(Opcional) Execute o exemplo localmente

Pré-requisitos

Para executar o aplicativo de exemplo localmente, você precisa do seguinte instalado em seu ambiente de desenvolvimento:

Instale e execute o aplicativo de exemplo

  1. Clone ou baixe o repositório tfjs-examples .
  2. Mude para o diretório mnist-transfer-cnn :

    cd tfjs-examples/mnist-transfer-cnn
    
  3. Instalar dependências:

    yarn
    
  4. Inicie o servidor de desenvolvimento:

    yarn run watch
    

Explorar o exemplo

Abra o aplicativo de exemplo . (Ou, se estiver executando o exemplo localmente, acesse http://localhost:1234 em seu navegador.)

Você deve ver uma página intitulada MNIST CNN Transfer Learning . Siga as instruções para experimentar o aplicativo.

Aqui estão algumas coisas para tentar:

  • Experimente os diferentes modos de treinamento e compare a perda e a precisão.
  • Selecione diferentes exemplos de bitmap e inspecione as probabilidades de classificação. Observe que os números em cada exemplo de bitmap são valores inteiros em tons de cinza que representam pixels de uma imagem.
  • Edite os valores inteiros do bitmap e veja como as alterações afetam as probabilidades de classificação.

Explorar o código

O aplicativo da Web de exemplo carrega um modelo que foi pré-treinado em um subconjunto do conjunto de dados MNIST. O pré-treinamento é definido em um programa Python: mnist_transfer_cnn.py . O programa Python está fora do escopo deste tutorial, mas vale a pena dar uma olhada se você quiser ver um exemplo de conversão de modelo .

O arquivo index.js contém a maior parte do código de treinamento da demonstração. Quando index.js é executado no navegador, uma função de configuração, setupMnistTransferCNN , instancia e inicializa MnistTransferCNNPredictor , que encapsula as rotinas de retreinamento e previsão.

O método de inicialização, MnistTransferCNNPredictor.init , carrega um modelo, carrega dados de retreinamento e cria dados de teste. Aqui está a linha que carrega o modelo:

this.model = await loader.loadHostedPretrainedModel(urls.model);

Se você observar a definição de loader.loadHostedPretrainedModel , verá que ele retorna o resultado de uma chamada para tf.loadLayersModel . Esta é a API do TensorFlow.js para carregar um modelo composto por objetos Layer.

A lógica de retreinamento é definida em MnistTransferCNNPredictor.retrainModel . Se o usuário selecionou Congelar camadas de feição como o modo de treinamento, as 7 primeiras camadas do modelo base são congeladas e apenas as 5 camadas finais são treinadas em novos dados. Se o usuário tiver selecionado Reinitialize weights , todos os pesos serão redefinidos e o aplicativo efetivamente treinará o modelo a partir do zero.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

O modelo é então compilado e treinado nos dados de teste usando model.fit() :

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

Para saber mais sobre os parâmetros model.fit() , consulte a documentação da API .

Após ser treinado no novo conjunto de dados (dígitos 5-9), o modelo pode ser usado para fazer previsões. O método MnistTransferCNNPredictor.predict faz isso usando model.predict() :

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

Observe o uso de tf.tidy , que ajuda a evitar vazamentos de memória.

Saber mais

Este tutorial explorou um aplicativo de exemplo que executa o aprendizado de transferência no navegador usando o TensorFlow.js. Confira os recursos abaixo para saber mais sobre modelos pré-treinados e aprendizado de transferência.

TensorFlow.js

Núcleo do TensorFlow