機械学習モデルをトレーニングする場合、トレーニング データが取り込まれ (または生成され)、モデル内でバッチが実行され、勾配が取得され、オプティマイザーを介してモデルが更新されるループが一般的です。トレーニング アプリケーションごとに独自のトレーニング ループを作成できますが、Swift for TensorFlow は、このプロセスを簡素化できる実験的なトレーニング ループの抽象化を提供します。
モデル リポジトリ内のTrainingLoop
モジュールには、この実験的な一般化トレーニング ループの現在のバージョンが含まれています。これは、Epochs API に準拠したデータセット ラッパーと統合してデータの取り込みを容易にし、モデル、データセット、オプティマイザーとアクセラレータ バックエンドの相互作用を自動化して最適なパフォーマンスを実現するような方法で構造化されています。トレーニング プロセスの大幅なカスタマイズは、コールバックを使用して実現できます。
モデル リポジトリ内のほとんどの画像ベースのサンプルは、教師ありテキスト モデルのトレーニング サンプルと同様に、このトレーニング ループの抽象化を使用するように変換されています。ただし、トレーニング ループは、現在の設計ではすべての機械学習モデルに適切ではない可能性があります。
TensorFlow の一般化されたトレーニング ループ用の Swift の実装は、fastai の Learnerに大きく影響されています。設計の詳細については、 「fastai: 深層学習のための階層化された API」および Sylvain Gugger のプレゼンテーション「Fast.ai - 無限にカスタマイズ可能なトレーニング ループ」を参照してください。
使用法
ResNet-CIFAR10 の例は、このトレーニング ループを実際に使用する方法をうまく示しています。まず、モジュールをインポートします。
import TrainingLoop
次に、 Device
を設定してアクセラレータ バックエンドを選択します。この場合、X10 XLA ベースのバックエンドを選択し、最初に利用可能なアクセラレータを使用します。
let device = Device.defaultXLA
次のステップでは、トレーニング ループで使用するデータセット、モデル、オプティマイザーを構成します。
let dataset = CIFAR10(batchSize: 10, on: device)
var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
var optimizer = SGD(for: model, learningRate: 0.001)
次に、トレーニング ループを設定します。
var trainingLoop = TrainingLoop(
training: dataset.training,
validation: dataset.validation,
optimizer: optimizer,
lossFunction: softmaxCrossEntropy,
metrics: [.accuracy])
トレーニング ループでは、使用しているデータセットが Epochs API に準拠していることを前提としており、データセット内のどの分割をトレーニングと検証に使用するかを指定できます。どの損失関数も、 softmaxCrossEntropy
などの互換性のあるラッパーに配置すると使用できます。
取得できる現在のメトリクスには次のものがあります。
-
loss
-
accuracy
-
top5Accuracy
-
matthewsCorrelationCoefficient
-
perplexity
最後に、トレーニングを実行するには、次を呼び出します。
try! trainingLoop.fit(&model, epochs: 10, on: device)
これにより、指定したアクセラレータ バックエンドを使用して 10 エポックの間モデルがトレーニングされます。統計は、トレーニング中にアニメーション プロンプトを使用してコンソールに表示されます。
コールバック
この一般化されたトレーニング ループのカスタマイズは、コールバックを使用して行われます。これらのコールバックは、ループ内のさまざまなポイントにフックできます。
いくつかの組み込みコールバックは、任意のトレーニング ループに追加できる機能を提供します。これらには次のものが含まれます。
- コンマ区切り値 (CSV) ファイルへの統計のログ記録
- カスタム スケジュールに従って学習率を調整する
- TensorBoard を介したトレーニングの進行状況のモニタリングとグラフ化
これらに加えて、独自のカスタム コールバックを作成して、標準のトレーニング ループにさまざまな機能を追加できます。
CSVロギング
CSVLogger
クラスは、トレーニング統計をカンマ区切り値形式で選択したファイルに書き出すコールバックをカプセル化します。このファイルは、 epoch
、 batch
、およびトレーニング ループ内で有効にしたメトリクスのラベルが付けられた列で始まります。その後、バッチごとに 1 行が書き込まれ、それらの列の現在の値が書き込まれます。
CSV ログをトレーニング ループに追加するには、 TrainingLoop
のcallbacks:
パラメーターに提供されるコールバックの配列に次のようなものを追加します。
try! CSVLogger(path: "file.csv").log
例として、 LeNet-MNIST
サンプルはトレーニング ループ内でこれを使用しています。
学習率スケジュール
モデルをトレーニングする場合、トレーニング プロセス中にオプティマイザーに提供される学習率を変更するのが一般的です。これは、時間の経過とともに直線的に減少するような単純なものもあれば、複雑な関数で記述されるウォームアップと衰退サイクルのように複雑なものもあります。
learningRateScheduler
コールバックは、それぞれが独自の形状を持つさまざまなセグメントで構成される学習率スケジュールを記述する手段を提供します。これは、関数によって定義されたShape
、初期学習率、および最終学習率を持つScheduleSegment
で構成されるLearningRateSchedule
を定義することによって実現されます。
たとえば、 BERT-CoLA サンプルでは、ウォームアップ期間中の学習率の線形増加とその後の線形減少を使用しています。これを行うには、学習率スケジュール コールバックを次のように定義します。
learningRateScheduler(
schedule: makeSchedule(
[
ScheduleSegment(shape: linear, startRate: 0, endRate: peakLearningRate, stepCount: 10),
ScheduleSegment(shape: linear, endRate: 0)
]
)
)
2 つのScheduleSegment
は、0 から始まり、一連の 10 個の離散ステップにわたってpeakLearningRate
まで直線的に増加し、その後、前のステップの最終学習率で開始し、トレーニング プロセスの終了までに 0 まで直線的に減少する学習率を定義します。
TensorBoard の統合
TensorBoard は、モデル トレーニングのモニタリング、完了時のトレーニングの分析、またはトレーニング実行の比較を行うための強力な視覚化ツールです。 Swift for TensorFlow は、モデル リポジトリのTensorBoard
モジュールを使用して TensorBoard 視覚化をサポートし、トレーニング メトリクスを記録するコールバックを提供します。
GPT2-WikiText2サンプルは、モデル トレーニングに TensorBoard ログを追加する方法を示しています。まず、 TensorBoard
モジュールをインポートします。あとは、 tensorBoardStatisticsLogger()
をTrainingLoop
のcallbacks:
array に追加するだけです。
デフォルトでは、各トレーニングの実行がrun/tensorboard/stats
ディレクトリ内に記録されます。 Tensorboard 内でこれを表示するには、次を実行します。
tensorboard --logdir ./run/tensorboard/stats
TensorBoard は、トレーニング メトリクスを表示できるローカル サーバーを起動する必要があります。トレーニングと検証の結果は個別に表示する必要があり、各実行には一意のタイムスタンプが付いているため、同じモデルの複数の実行を簡単に比較できます。
Swift for TensorFlow TensorBoard 統合の設計は、 tensorboardXからインスピレーションを受けています。 TensorBoard コールバックは、適切なイベントおよび要約プロトコル バッファーを直接作成し、トレーニング中にログ ファイル内にそれらを書き込みます。
カスタムコールバック
上記の組み込みコールバックに加えて、独自のコールバックを作成してトレーニング ループの機能をカスタマイズできます。これらのコールバックは、次のようなシグネチャを持つ関数です。
func customCallback<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws
{
if event == .updateStart {
...
}
}
トレーニング ループと関連する状態が最初のパラメーターとして渡されます。コールバックが応答しているループの現在の部分は、 event
介して提供されます。トレーニング ループ イベントには次のいずれかの状態があり、それぞれがループのライフ サイクルの異なる時点に対応します。
-
fitStart
-
fitEnd
-
epochStart
-
epochEnd
-
trainingStart
-
trainingEnd
-
validationStart
-
validationEnd
-
batchStart
-
batchEnd
-
updateStart
-
inferencePredictionEnd
コールバック関数は、上記の状態の任意の組み合わせでロジックをアクティブにすることを選択でき、これにより、さまざまな方法でトレーニング ループからデータを抽出したり、トレーニング ループを制御したりすることができます。