Lingkaran pelatihan

Saat melatih model pembelajaran mesin, biasanya terdapat loop tempat data pelatihan diserap (atau dihasilkan), batch dijalankan melalui model, gradien diperoleh, dan model diperbarui melalui pengoptimal. Meskipun Anda dapat menulis loop pelatihan Anda sendiri untuk setiap aplikasi pelatihan, Swift untuk TensorFlow menyediakan abstraksi loop pelatihan eksperimental yang dapat menyederhanakan proses ini.

Modul TrainingLoop dalam repositori model berisi versi terbaru dari loop pelatihan umum eksperimental ini. Ini disusun sedemikian rupa untuk berintegrasi dengan pembungkus kumpulan data yang sesuai dengan Epochs API untuk memudahkan penyerapan data, dan untuk mengotomatiskan interaksi model, kumpulan data, dan pengoptimal dengan backend akselerator untuk mencapai kinerja optimal. Kustomisasi besar-besaran pada proses pelatihan dapat dicapai melalui penggunaan callback.

Sebagian besar contoh berbasis gambar di repositori model telah dikonversi untuk menggunakan abstraksi loop pelatihan ini, serta contoh pelatihan model teks yang diawasi. Namun, loop pelatihan mungkin tidak sesuai dengan desain saat ini untuk semua model pembelajaran mesin.

Implementasi loop pelatihan umum Swift untuk TensorFlow sangat dipengaruhi oleh Learner fastai . Untuk mengetahui lebih lanjut tentang desainnya, lihat "fastai: API Berlapis untuk Pembelajaran Mendalam" dan presentasi Sylvain Gugger "Fast.ai - Lingkaran pelatihan yang dapat disesuaikan tanpa batas" .

Penggunaan

Contoh ResNet-CIFAR10 memberikan demonstrasi yang baik tentang cara menggunakan loop pelatihan ini dalam praktik. Pertama, impor modul:

import TrainingLoop

lalu pilih backend akselerator dengan menyiapkan Device . Dalam hal ini, kami akan memilih backend berbasis X10 XLA dan menggunakan akselerator pertama yang tersedia:

let device = Device.defaultXLA

Langkah selanjutnya adalah mengonfigurasi kumpulan data, model, dan pengoptimal untuk digunakan dengan loop pelatihan Anda:

let dataset = CIFAR10(batchSize: 10, on: device)
var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
var optimizer = SGD(for: model, learningRate: 0.001)

lalu atur loop pelatihan:

var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.accuracy])

Perulangan pelatihan mengasumsikan bahwa kumpulan data yang Anda gunakan sesuai dengan Epochs API, dan memungkinkan Anda menentukan pemisahan mana dalam kumpulan data yang akan digunakan untuk pelatihan dan validasi. Fungsi kerugian apa pun dapat digunakan setelah ditempatkan ke dalam pembungkus yang kompatibel, seperti softmaxCrossEntropy is here .

Metrik saat ini yang dapat ditangkap meliputi:

  • loss
  • accuracy
  • top5Accuracy
  • matthewsCorrelationCoefficient
  • perplexity

Terakhir, untuk melakukan pelatihan, Anda memanggil yang berikut ini:

try! trainingLoop.fit(&model, epochs: 10, on: device)

Ini akan melatih model selama 10 epoch menggunakan backend akselerator yang kami tentukan. Statistik akan ditampilkan selama pelatihan ke konsol menggunakan perintah animasi.

Panggilan balik

Penyesuaian loop pelatihan umum ini terjadi melalui penggunaan callback. Callback ini dapat dihubungkan ke berbagai titik dalam loop.

Beberapa callback bawaan menyediakan fungsionalitas yang dapat ditambahkan ke loop pelatihan apa pun. Ini termasuk:

  • Mencatat statistik ke file nilai yang dipisahkan koma (CSV).
  • Menyesuaikan kecepatan pembelajaran sesuai dengan jadwal khusus
  • Memantau dan membuat grafik kemajuan pelatihan melalui TensorBoard

Selain itu, Anda dapat membuat callback kustom Anda sendiri untuk menambahkan serangkaian fungsi tambahan ke loop pelatihan standar.

pencatatan CSV

Kelas CSVLogger merangkum panggilan balik yang akan menulis statistik pelatihan dalam format nilai yang dipisahkan koma ke file pilihan Anda. File ini akan dimulai dengan kolom berlabel epoch , batch , dan metrik apa pun yang telah Anda aktifkan dalam loop pelatihan Anda. Satu baris kemudian akan ditulis untuk setiap kumpulan, dengan nilai kolom tersebut saat ini.

Untuk menambahkan logging CSV ke loop pelatihan Anda, tambahkan sesuatu seperti berikut ke array callback yang disediakan untuk callbacks: parameter untuk TrainingLoop Anda :

try! CSVLogger(path: "file.csv").log

Sebagai contoh, sampel LeNet-MNIST menggunakan ini dalam loop pelatihannya.

Jadwal kecepatan pembelajaran

Hal yang biasa terjadi saat melatih model untuk mengubah kecepatan pembelajaran yang diberikan kepada pengoptimal selama proses pelatihan. Hal ini dapat berupa penurunan linier seiring berjalannya waktu, atau serumit siklus pemanasan dan penurunan yang dijelaskan oleh fungsi yang rumit.

Callback learningRateScheduler menyediakan sarana untuk mendeskripsikan jadwal kecepatan pembelajaran yang terdiri dari segmen berbeda, yang masing-masing memiliki bentuk berbeda. Hal ini dicapai dengan mendefinisikan LearningRateSchedule yang terdiri dari ScheduleSegment yang masing-masing memiliki Shape yang ditentukan oleh fungsi, kecepatan pembelajaran awal, dan kecepatan pembelajaran akhir.

Misalnya, sampel BERT-CoLA menggunakan peningkatan linier dalam kecepatan pemelajaran selama periode pemanasan dan penurunan linier setelahnya. Untuk melakukan hal ini, callback jadwal kecepatan pembelajaran didefinisikan sebagai berikut:

learningRateScheduler(
  schedule: makeSchedule(
    [
      ScheduleSegment(shape: linear, startRate: 0, endRate: peakLearningRate, stepCount: 10),
      ScheduleSegment(shape: linear, endRate: 0)
    ]
  )
)

Kedua ScheduleSegment menentukan kecepatan pembelajaran yang dimulai dari 0 dan meningkat secara linear ke peakLearningRate melalui serangkaian 10 langkah terpisah, kemudian dimulai pada kecepatan pembelajaran akhir dari langkah sebelumnya dan menurun secara linear ke 0 pada akhir proses pelatihan.

Integrasi TensorBoard

TensorBoard adalah alat visualisasi yang ampuh untuk memantau pelatihan model, menganalisis pelatihan setelah selesai, atau membandingkan pelatihan yang berjalan. Swift untuk TensorFlow mendukung visualisasi TensorBoard melalui penggunaan modul TensorBoard di repositori model, yang menyediakan callback yang mencatat metrik pelatihan.

Contoh GPT2-WikiText2 mengilustrasikan cara menambahkan logging TensorBoard ke pelatihan model Anda. Pertama, impor modul TensorBoard . Maka itu semudah menambahkan tensorBoardStatisticsLogger() ke array callbacks: TrainingLoop Anda:.

Secara default, ini akan mencatat setiap pelatihan yang dijalankan dalam direktori run/tensorboard/stats . Untuk melihatnya dalam Tensorboard, jalankan

tensorboard --logdir ./run/tensorboard/stats

dan TensorBoard harus memulai server lokal tempat Anda dapat melihat metrik pelatihan Anda. Hasil pelatihan dan validasi harus ditampilkan secara terpisah, dan setiap proses memiliki stempel waktu unik untuk memudahkan perbandingan antara beberapa proses pada model yang sama.

Desain integrasi TensorBoard Swift untuk TensorFlow terinspirasi oleh tensorboardX . Callback TensorBoard secara langsung membuat buffer protokol peristiwa dan ringkasan yang sesuai dan menuliskannya dalam file log selama pelatihan.

Panggilan balik khusus

Selain callback bawaan yang dijelaskan di atas, Anda memiliki kemampuan untuk menyesuaikan fungsi loop pelatihan dengan membuat callback Anda sendiri. Callback ini adalah fungsi yang memiliki tanda tangan yang mirip dengan berikut ini:

func customCallback<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws
{
  if event == .updateStart {
    ...
  }
}

Loop pelatihan dan status terkait diteruskan sebagai parameter pertama. Bagian perulangan saat ini yang ditanggapi oleh panggilan balik disediakan melalui event . Peristiwa perulangan pelatihan memiliki salah satu status berikut, masing-masing berkaitan dengan titik berbeda dalam siklus hidup perulangan:

  • fitStart
  • fitEnd
  • epochStart
  • epochEnd
  • trainingStart
  • trainingEnd
  • validationStart
  • validationEnd
  • batchStart
  • batchEnd
  • updateStart
  • inferencePredictionEnd

Fungsi panggilan balik Anda dapat memilih untuk mengaktifkan logikanya pada kombinasi status di atas, yang memungkinkan untuk mengekstraksi data dari atau mengendalikan loop pelatihan dengan banyak cara.