Neste tutorial, você explorará um aplicativo da Web de exemplo que demonstra o aprendizado de transferência usando a API de camadas do TensorFlow.js. O exemplo carrega um modelo pré-treinado e treina novamente o modelo no navegador.
O modelo foi pré-treinado em Python nos dígitos 0-4 do conjunto de dados de classificação de dígitos MNIST . O retreinamento (ou aprendizado de transferência) no navegador usa os dígitos de 5 a 9. O exemplo mostra que as primeiras camadas de um modelo pré-treinado podem ser usadas para extrair recursos de novos dados durante o aprendizado de transferência, permitindo assim um treinamento mais rápido nos novos dados.
O aplicativo de exemplo para este tutorial está disponível online , então você não precisa baixar nenhum código ou configurar um ambiente de desenvolvimento. Se quiser executar o código localmente, conclua as etapas opcionais em Executar o exemplo localmente . Se você não deseja configurar um ambiente de desenvolvimento, pode pular para Explorar o exemplo .
O código de exemplo está disponível no GitHub .
(Opcional) Execute o exemplo localmente
Pré-requisitos
Para executar o aplicativo de exemplo localmente, você precisa do seguinte instalado em seu ambiente de desenvolvimento:
Instale e execute o aplicativo de exemplo
- Clone ou baixe o repositório
tfjs-examples
. Mude para o diretório
mnist-transfer-cnn
:cd tfjs-examples/mnist-transfer-cnn
Instalar dependências:
yarn
Inicie o servidor de desenvolvimento:
yarn run watch
Explorar o exemplo
Abra o aplicativo de exemplo . (Ou, se estiver executando o exemplo localmente, acesse http://localhost:1234
em seu navegador.)
Você deve ver uma página intitulada MNIST CNN Transfer Learning . Siga as instruções para experimentar o aplicativo.
Aqui estão algumas coisas para tentar:
- Experimente os diferentes modos de treinamento e compare a perda e a precisão.
- Selecione diferentes exemplos de bitmap e inspecione as probabilidades de classificação. Observe que os números em cada exemplo de bitmap são valores inteiros em tons de cinza que representam pixels de uma imagem.
- Edite os valores inteiros do bitmap e veja como as alterações afetam as probabilidades de classificação.
Explorar o código
O aplicativo da Web de exemplo carrega um modelo que foi pré-treinado em um subconjunto do conjunto de dados MNIST. O pré-treinamento é definido em um programa Python: mnist_transfer_cnn.py
. O programa Python está fora do escopo deste tutorial, mas vale a pena dar uma olhada se você quiser ver um exemplo de conversão de modelo .
O arquivo index.js
contém a maior parte do código de treinamento da demonstração. Quando index.js
é executado no navegador, uma função de configuração, setupMnistTransferCNN
, instancia e inicializa MnistTransferCNNPredictor
, que encapsula as rotinas de retreinamento e previsão.
O método de inicialização, MnistTransferCNNPredictor.init
, carrega um modelo, carrega dados de retreinamento e cria dados de teste. Aqui está a linha que carrega o modelo:
this.model = await loader.loadHostedPretrainedModel(urls.model);
Se você observar a definição de loader.loadHostedPretrainedModel
, verá que ele retorna o resultado de uma chamada para tf.loadLayersModel
. Esta é a API do TensorFlow.js para carregar um modelo composto por objetos Layer.
A lógica de retreinamento é definida em MnistTransferCNNPredictor.retrainModel
. Se o usuário selecionou Congelar camadas de feição como o modo de treinamento, as 7 primeiras camadas do modelo base são congeladas e apenas as 5 camadas finais são treinadas em novos dados. Se o usuário tiver selecionado Reinitialize weights , todos os pesos serão redefinidos e o aplicativo efetivamente treinará o modelo a partir do zero.
if (trainingMode === 'freeze-feature-layers') {
console.log('Freezing feature layers of the model.');
for (let i = 0; i < 7; ++i) {
this.model.layers[i].trainable = false;
}
} else if (trainingMode === 'reinitialize-weights') {
// Make a model with the same topology as before, but with re-initialized
// weight values.
const returnString = false;
this.model = await tf.models.modelFromJSON({
modelTopology: this.model.toJSON(null, returnString)
});
}
O modelo é então compilado e treinado nos dados de teste usando model.fit()
:
await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
batchSize: batchSize,
epochs: epochs,
validationData: [this.gte5TestData.x, this.gte5TestData.y],
callbacks: [
ui.getProgressBarCallbackConfig(epochs),
tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
zoomToFit: true,
zoomToFitAccuracy: true,
height: 200,
callbacks: ['onEpochEnd'],
}),
]
});
Para saber mais sobre os parâmetros model.fit()
, consulte a documentação da API .
Após ser treinado no novo conjunto de dados (dígitos 5-9), o modelo pode ser usado para fazer previsões. O método MnistTransferCNNPredictor.predict
faz isso usando model.predict()
:
// Perform prediction on the input image using the loaded model.
predict(imageText) {
tf.tidy(() => {
try {
const image = util.textToImageArray(imageText, this.imageSize);
const predictOut = this.model.predict(image);
const winner = predictOut.argMax(1);
ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
} catch (e) {
ui.setPredictError(e.message);
}
});
}
Observe o uso de tf.tidy
, que ajuda a evitar vazamentos de memória.
Saber mais
Este tutorial explorou um aplicativo de exemplo que executa o aprendizado de transferência no navegador usando o TensorFlow.js. Confira os recursos abaixo para saber mais sobre modelos pré-treinados e aprendizado de transferência.
TensorFlow.js
- Importando um modelo Keras para o TensorFlow.js
- Importar um modelo do TensorFlow para o TensorFlow.js
- Modelos pré-fabricados para TensorFlow.js
Núcleo do TensorFlow