TensorFlow Java は、機械学習モデルの構築、トレーニング、デプロイのために任意の JVM 上で実行できます。グラフ モードまたはイーガー モードで CPU と GPU の両方の実行をサポートし、JVM 環境で TensorFlow を使用するための豊富な API を提供します。 Java および Scala や Kotlin などのその他の JVM 言語は、世界中の大企業から中小企業まで頻繁に使用されているため、TensorFlow Java は大規模な機械学習を導入するための戦略的な選択肢となっています。
要件
TensorFlow Java は Java 8 以降で実行され、すぐに使用できる次のプラットフォームをサポートします。
- Ubuntu 16.04以降。 64ビット、x86
- macOS 10.12.6 (Sierra) 以降。 64ビット、x86
- Windows 7 以降。 64ビット、x86
バージョン
TensorFlow Java には、 TensorFlow ランタイムから独立した独自のリリース サイクルがあります。したがって、そのバージョンは、それが実行される TensorFlow ランタイムのバージョンと一致しません。 TensorFlow Javaバージョン表を参照して、利用可能なすべてのバージョンと TensorFlow ランタイムとのマッピングをリストします。
アーティファクト
TensorFlow Java をプロジェクトに追加するには、いくつかの方法があります。最も簡単な方法はtensorflow-core-platform
アーティファクトに依存関係を追加することです。これには、TensorFlow Java Core API と、サポートされているすべてのプラットフォームで実行するために必要なネイティブ依存関係の両方が含まれます。
純粋な CPU バージョンの代わりに、次の拡張機能のいずれかを選択することもできます。
-
tensorflow-core-platform-mkl
: すべてのプラットフォームでのインテル® MKL-DNN のサポート tensorflow-core-platform-gpu
: Linux および Windows プラットフォームでの CUDA® のサポートtensorflow-core-platform-mkl-gpu
: Linux プラットフォームでの Intel® MKL-DNN および CUDA® のサポート。
さらに、 tensorflow-framework
ライブラリへの別の依存関係を追加すると、JVM 上の TensorFlow ベースの機械学習用の豊富なユーティリティ セットの恩恵を受けることができます。
Maven を使用したインストール
TensorFlow をMavenアプリケーションに含めるには、そのアーティファクトへの依存関係をプロジェクトのpom.xml
ファイルに追加します。例えば、
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-platform</artifactId>
<version>0.3.3</version>
</dependency>
依存関係の数を減らす
tensorflow-core-platform
アーティファクトに依存関係を追加すると、サポートされているすべてのプラットフォームのネイティブ ライブラリがインポートされるため、プロジェクトのサイズが大幅に増加する可能性があることに注意することが重要です。
利用可能なプラットフォームのサブセットをターゲットにする場合は、 Maven 依存関係の除外機能を使用して、他のプラットフォームから不要なアーティファクトを除外できます。
アプリケーションに含めるプラットフォームを選択するもう 1 つの方法は、Maven コマンドラインまたはpom.xml
で JavaCPP システム プロパティを設定することです。詳細については、JavaCPP のドキュメントを参照してください。
スナップショットの使用
TensorFlow Java ソース リポジトリからの最新の TensorFlow Java 開発スナップショットは、 OSS Sonatype Nexus リポジトリで利用できます。これらのアーティファクトに依存するには、 pom.xml
で OSS スナップショット リポジトリを必ず構成してください。
<repositories>
<repository>
<id>tensorflow-snapshots</id>
<url>https://oss.sonatype.org/content/repositories/snapshots/</url>
<snapshots>
<enabled>true</enabled>
</snapshots>
</repository>
</repositories>
<dependencies>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-platform</artifactId>
<version>0.4.0-SNAPSHOT</version>
</dependency>
</dependencies>
Gradle を使用したインストール
TensorFlow をGradleアプリケーションに含めるには、そのアーティファクトへの依存関係をプロジェクトのbuild.gradle
ファイルに追加します。例えば、
repositories {
mavenCentral()
}
dependencies {
compile group: 'org.tensorflow', name: 'tensorflow-core-platform', version: '0.3.3'
}
依存関係の数を減らす
Gradle を使用して TensorFlow Java からネイティブ アーティファクトを除外することは、Maven の場合ほど簡単ではありません。この依存関係の数を減らすには、Gradle JavaCPP プラグインを使用することをお勧めします。
詳細については、Gradle JavaCPPドキュメントを参照してください。
ソースからのインストール
TensorFlow Java をソースからビルドし、場合によってはカスタマイズするには、次の手順をお読みください。
サンプルプログラム
この例では、TensorFlow を使用して Apache Maven プロジェクトを構築する方法を示します。まず、TensorFlow 依存関係をプロジェクトのpom.xml
ファイルに追加します。
<project>
<modelVersion>4.0.0</modelVersion>
<groupId>org.myorg</groupId>
<artifactId>hellotensorflow</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<exec.mainClass>HelloTensorFlow</exec.mainClass>
<!-- Minimal version for compiling TensorFlow Java is JDK 8 -->
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>
<dependencies>
<!-- Include TensorFlow (pure CPU only) for all supported platforms -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-platform</artifactId>
<version>0.3.3</version>
</dependency>
</dependencies>
</project>
ソース ファイルsrc/main/java/HelloTensorFlow.java
を作成します。
import org.tensorflow.ConcreteFunction;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.math.Add;
import org.tensorflow.types.TInt32;
public class HelloTensorFlow {
public static void main(String[] args) throws Exception {
System.out.println("Hello TensorFlow " + TensorFlow.version());
try (ConcreteFunction dbl = ConcreteFunction.create(HelloTensorFlow::dbl);
TInt32 x = TInt32.scalarOf(10);
Tensor dblX = dbl.call(x)) {
System.out.println(x.getInt() + " doubled is " + ((TInt32)dblX).getInt());
}
}
private static Signature dbl(Ops tf) {
Placeholder<TInt32> x = tf.placeholder(TInt32.class);
Add<TInt32> dblX = tf.math.add(x, x);
return Signature.builder().input("x", x).output("dbl", dblX).build();
}
}
コンパイルして実行します。
mvn -q compile exec:java
このコマンドは、TensorFlow のバージョンと簡単な計算を出力します。
成功! TensorFlow Java が設定されています。