eaplatanios/tensorflow_scala: TensorFlow API for the Scala Programming Language



This library is a Scala API for TensorFlow (https://www.tensorflow.org). It aims to provide most of the functionality of the official Python API while being strongly typed and adding some new features. It is a work in progress, and a project I started for my own research. Much of the API should be relatively stable by now, but things are still likely to change.


Please refer to the main website for documentation and tutorials. Here are a few useful links:

Citation

If you use this library in your work, I would greatly appreciate it if you cited it using the following BibTeX entry:

@misc{Platanios:2018:tensorflow-scala,
  title        = {{TensorFlow Scala}},
  author       = {Platanios, Emmanouil Antonios},
  howpublished = {\url{https://github.com/eaplatanios/tensorflow_scala}},
  year         = {2018}
}

Main Features

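import org.platanios.tensorflow.api._
import org.platanios.tensorflow.api.learn._
import org.platanios.tensorflow.api.learn.layers._
import org.platanios.tensorflow.api.learn.estimators.InMemoryEstimator

// `trainData` is assumed to be defined elsewhere: a dataset of (image, label)
// batches, e.g. MNIST images of shape [batch, 28, 28] with INT64 labels.

// Create the MLP model.
val input = Input(FLOAT32, Shape(-1, 28, 28))
val trainInput = Input(INT64, Shape(-1))
val layer = Flatten[Float]("Input/Flatten") >>
    Linear[Float]("Layer_0", 128) >> ReLU[Float]("Layer_0/Activation", 0.1f) >>
    Linear[Float]("Layer_1", 64) >> ReLU[Float]("Layer_1/Activation", 0.1f) >>
    Linear[Float]("Layer_2", 32) >> ReLU[Float]("Layer_2/Activation", 0.1f) >>
    Linear[Float]("OutputLayer", 10)
// Softmax cross-entropy over integer class labels, averaged over the batch.
val loss = SparseSoftmaxCrossEntropy[Float, Long, Float]("Loss/CrossEntropy") >>
    Mean("Loss/Mean")
// Plain gradient descent with a learning rate of 1e-6.
val optimizer = tf.train.GradientDescent(1e-6f)
val model = Model.simpleSupervised(input, trainInput, layer, loss, optimizer)

// Create an estimator and train the model for up to 1,000,000 steps.
val estimator = InMemoryEstimator(model)
estimator.train(() => trainData, StopCriteria(maxSteps = Some(1000000)))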
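The `>>` operator in the model definition chains layers so that each layer's output becomes the next layer's input. The following plain-Scala sketch illustrates that idea with a hypothetical `Layer` case class; it does not use or reproduce TensorFlow Scala's actual layer classes. It also assumes, for illustration only, that the `0.1f` passed to `ReLU` is the slope of a leaky ReLU for negative inputs.

```scala
// Hypothetical mini-DSL (not the TensorFlow Scala API): a layer is a named function from A to B.
final case class Layer[A, B](name: String, f: A => B) {
  // Apply the layer to an input.
  def apply(x: A): B = f(x)

  // `>>` feeds this layer's output into `next`, i.e. plain function composition.
  def >>[C](next: Layer[B, C]): Layer[A, C] =
    Layer(s"$name >> ${next.name}", f andThen next.f)
}

object LayerCompositionSketch {
  // Leaky ReLU: keeps positive values, scales negative values by `alpha`.
  def leakyReLU(name: String, alpha: Float): Layer[Vector[Float], Vector[Float]] =
    Layer(name, (xs: Vector[Float]) => xs.map(v => if (v > 0f) v else alpha * v))

  // Flatten a 2-D "image" into a 1-D vector.
  val flatten: Layer[Vector[Vector[Float]], Vector[Float]] =
    Layer("Flatten", (img: Vector[Vector[Float]]) => img.flatten)

  // Sum all entries; a stand-in for a real output layer.
  val sum: Layer[Vector[Float], Float] =
    Layer("Sum", (xs: Vector[Float]) => xs.sum)

  // The composed pipeline: Flatten, then leaky ReLU, then Sum.
  val model: Layer[Vector[Vector[Float]], Float] =
    flatten >> leakyReLU("Activation", 0.1f) >> sum
}
```

For example, `LayerCompositionSketch.model(Vector(Vector(1f, -2f), Vector(3f, -4f)))` flattens the input, scales the negative entries by 0.1, and sums the result to roughly 3.4. Because composition is just `andThen`, the input and output types of adjacent layers must line up at compile time, which is the kind of guarantee a strongly typed layer API can provide.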
By changing a few lines as follows, you also get checkpointing, summaries, and seamless integration with TensorBoard:
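import java.nio.file.Paths
import org.platanios.tensorflow.api.config.TensorBoardConfig
import org.platanios.tensorflow.api.learn.hooks._

// Same loss as before, plus a scalar summary of its value for TensorBoard.
val loss = SparseSoftmaxCrossEntropy[Float, Long, Float]("Loss/CrossEntropy") >>
    Mean("Loss/Mean") >>
    ScalarSummary(name = "Loss", tag = "Loss")
// Rebuild `model` with this loss (as in the previous snippet) before creating the estimator.
val summariesDir = Paths.get("/tmp/summaries")
val estimator = InMemoryEstimator(
  modelFunction = model,
  configurationBase = Configuration(Some(summariesDir)),
  trainHooks = Set(
    // Write summaries every 100 steps and checkpoints every 1000 steps.
    SummarySaver(summariesDir, StepHookTrigger(100)),
    CheckpointSaver(summariesDir, StepHookTrigger(1000))),
  tensorBoardConfig = TensorBoardConfig(summariesDir))
estimator.train(() => trainData, StopCriteria(maxSteps = Some(100000)))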
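In the configuration above, `StepHookTrigger(100)` and `StepHookTrigger(1000)` schedule the summary and checkpoint hooks by training step. The sketch below is a hypothetical, self-contained model of such a trigger, assuming it fires once at least N steps have passed since it last fired (including at the first step). `StepTriggerSketch` and `HookScheduleSketch` are illustrative names, not library classes, and the real trigger's exact semantics may differ.

```scala
// Hypothetical step-based trigger (not the library's StepHookTrigger):
// fires on the first step it sees, then whenever `everyNSteps` steps have elapsed.
final class StepTriggerSketch(everyNSteps: Long) {
  require(everyNSteps > 0, "everyNSteps must be positive")

  // Step at which the trigger last fired, if any.
  private var lastTriggered: Option[Long] = None

  // Returns true (and records the step) if the hook should run at `step`.
  def shouldTrigger(step: Long): Boolean = lastTriggered match {
    case Some(last) if step - last < everyNSteps => false
    case _ =>
      lastTriggered = Some(step)
      true
  }
}

object HookScheduleSketch {
  // Simulate a training loop of `totalSteps` steps and collect the steps at which the hook fires.
  def firingSteps(everyNSteps: Long, totalSteps: Long): Seq[Long] = {
    val trigger = new StepTriggerSketch(everyNSteps)
    (0L until totalSteps).filter(trigger.shouldTrigger)
  }
}
```

Under these assumptions, `HookScheduleSketch.firingSteps(1000, 2500)` returns the steps 0, 1000, and 2000, so over the 100,000 training steps above a checkpoint hook with `StepHookTrigger(1000)` would run about 100 times, and a summary hook with `StepHookTrigger(100)` about 1,000 times.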
If you browse to http://127.0.0.1:6006 while training, you can see the training progress:
[Figure: TensorBoard plot of the MNIST training progress]

Compiling from Source

Note that to compile TensorFlow Scala on your machine, you first need to install the TensorFlow Python API. You also need a python3 alias for your Python binary, which CMake uses to find the TensorFlow header files in your installation.
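Before building, it can help to check that the python3 alias works and that TensorFlow's headers can be located through it. The sketch below is a hypothetical helper, not part of the project's build: it shells out to `python3` and asks TensorFlow's `tf.sysconfig.get_include()` for its header directory. That this mirrors what the CMake build looks up is an assumption; the object name `TensorFlowIncludeSketch` is illustrative.

```scala
import scala.sys.process._
import scala.util.Try

// Hypothetical pre-build check: locate TensorFlow's headers through the `python3` alias.
object TensorFlowIncludeSketch {
  // Python snippet that prints TensorFlow's header (include) directory.
  val pythonSnippet: String =
    "import tensorflow as tf; print(tf.sysconfig.get_include())"

  // Command line invoking the `python3` alias the build expects.
  val command: Seq[String] = Seq("python3", "-c", pythonSnippet)

  // Returns the include directory, or None if `python3` is missing,
  // TensorFlow is not installed, or the command prints nothing.
  def includeDir(): Option[String] =
    Try(Process(command).!!.trim).toOption.filter(_.nonEmpty)
}
```

If `TensorFlowIncludeSketch.includeDir()` returns `None`, either the python3 alias or the TensorFlow Python installation is likely missing, and fixing that first should save a failed CMake run.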

Tutorials

Funding

Funding for the development of this library has been generously provided by the following sponsors:

- CMU Presidential Fellowship, awarded to Emmanouil Antonios Platanios
- National Science Foundation, Grant #: IIS1250956
- Air Force Office of Scientific Research, Grant #: FA95501710218

TensorFlow, the TensorFlow logo, and any related marks are trademarks of Google Inc.