Using Deep Learning in Java for Image Recognition

This tutorial walks you through building a state-of-the-art deep learning model for image recognition in Java. It will help you get started with modern AI development using only your Java skills.

The example used in this article is recognition of handwritten digits, which is commonly used as the "hello world" of deep learning.

The full source code for this project is available on GitHub, and you can use it as a starter for your own image recognition projects.

To run the example, you need Deep Netts, which is available for free download here.
A step-by-step guide for setting up Deep Netts is available here.

Data Set

Our data set consists of 60,000 images of handwritten digits, each 28×28 pixels in size. A few sample images are shown below.

Sample images from the data set

The example images used to train the deep learning model are located in the folder mnist/training. This folder contains ten sub-folders named 0–9, each of which contains images of a specific digit. For simplicity, this example uses a randomly chosen subset of 1000 images. The mnist folder also contains two files: index.txt, which lists the subset of images to use for training, and labels.txt, which contains the image category labels (in this case, the digits 0–9). The images are downloaded and unpacked automatically when you clone and run the example from GitHub.
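Because each digit has its own sub-folder, the label of an image can be read from the name of its parent folder. The sketch below is a minimal, JDK-only illustration of that idea, assuming the mnist/training layout described above: it collects (file, label) pairs and draws a reproducible random subset. The class and method names and the seed value are hypothetical, and this is not how Deep Netts itself reads index.txt and labels.txt.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class MnistFolderIndex {

    /** One training example: an image file and the digit label taken from its folder name. */
    record LabeledImage(Path file, String label) {}

    /** Walks root/0 ... root/9 and labels each .png with the name of its parent folder. */
    static List<LabeledImage> index(Path root) throws IOException {
        List<LabeledImage> result = new ArrayList<>();
        try (Stream<Path> dirs = Files.list(root)) {
            for (Path dir : dirs.filter(Files::isDirectory).sorted().collect(Collectors.toList())) {
                try (Stream<Path> files = Files.list(dir)) {
                    files.filter(p -> p.toString().endsWith(".png"))
                         .sorted()
                         .forEach(p -> result.add(new LabeledImage(p, dir.getFileName().toString())));
                }
            }
        }
        return result;
    }

    /** Returns up to n randomly chosen examples; a fixed seed makes the subset reproducible. */
    static List<LabeledImage> randomSubset(List<LabeledImage> all, int n, long seed) {
        List<LabeledImage> copy = new ArrayList<>(all);
        Collections.shuffle(copy, new Random(seed));
        return copy.subList(0, Math.min(n, copy.size()));
    }

    public static void main(String[] args) throws IOException {
        List<LabeledImage> all = index(Paths.get("mnist/training"));
        List<LabeledImage> subset = randomSubset(all, 1000, 42L);
        System.out.println("Indexed " + all.size() + " images, using " + subset.size());
    }
}
```

Fixing the seed means the same 1000 images are picked on every run, which makes experiments easier to compare.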

Model Training

Model training consists of iteratively presenting example images to a deep learning model (in this case, a convolutional neural network), during which the model automatically tunes its internal parameters to lower the output error and increase recognition accuracy. This automated tuning compares the image labels from the training data with the network's actual outputs (predictions) to calculate the output error, and uses an optimization technique to minimize that error. Each output of the convolutional neural network corresponds to a single digit and represents the probability that the input image belongs to that digit.
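To make "a probability per digit" and "output error" concrete, here is a small plain-Java sketch. It applies softmax (the output activation used in the builder code below) to ten raw scores and measures the error with cross-entropy against the correct digit. Cross-entropy is an assumption here, since the text does not name the loss function, and the class is a hypothetical illustration rather than Deep Netts code.

```java
public class SoftmaxError {

    /** Converts raw output scores into probabilities that are positive and sum to 1. */
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) max = Math.max(max, s);   // subtract the max for numerical stability
        double[] probs = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            probs[i] = Math.exp(scores[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) probs[i] /= sum;
        return probs;
    }

    /** Cross-entropy for a one-hot label: -ln(probability assigned to the correct digit). */
    static double crossEntropy(double[] probs, int correctDigit) {
        return -Math.log(Math.max(probs[correctDigit], 1e-12));   // clamp to avoid log(0)
    }

    public static void main(String[] args) {
        double[] scores = {0.1, 0.2, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
        double[] probs = softmax(scores);
        System.out.printf("P(digit 2) = %.3f, error = %.3f%n", probs[2], crossEntropy(probs, 2));
    }
}
```

With ten equal scores every probability is 0.1 and the error is ln 10 ≈ 2.303; as training raises the probability of the correct digit, the error falls toward 0.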

A convolutional neural network typically consists of a stack of processing blocks called layers. Convolutional layers perform pattern detection, pooling layers downsample their inputs, and fully connected layers perform classification. The number and size of the layers depend on the problem and are usually determined experimentally. The architecture of a small convolutional neural network is shown in the image below. More details about how all this works can be found in this tutorial.

Architecture of a convolutional neural network
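The following plain-Java sketch shows what "pattern detection" and "downsampling" mean for a single channel: one filter slid over the image (stride 1, no padding) and 2×2 max pooling with stride 2. Stride, padding and the helper names are assumptions for illustration; real convolutional layers use many filters plus biases and activation functions.

```java
public class ConvPoolSketch {

    /** Slides a k×k filter over the image (stride 1, no padding) and returns the feature map.
        Like most CNN libraries this computes cross-correlation, i.e. the filter is not flipped. */
    static double[][] convolve(double[][] image, double[][] filter) {
        int k = filter.length;
        int outH = image.length - k + 1;
        int outW = image[0].length - k + 1;
        double[][] out = new double[outH][outW];
        for (int r = 0; r < outH; r++) {
            for (int c = 0; c < outW; c++) {
                double sum = 0.0;
                for (int i = 0; i < k; i++)
                    for (int j = 0; j < k; j++)
                        sum += image[r + i][c + j] * filter[i][j];
                out[r][c] = sum;
            }
        }
        return out;
    }

    /** Keeps the largest value in each non-overlapping size×size window (stride = size). */
    static double[][] maxPool(double[][] input, int size) {
        int outH = input.length / size;
        int outW = input[0].length / size;
        double[][] out = new double[outH][outW];
        for (int r = 0; r < outH; r++) {
            for (int c = 0; c < outW; c++) {
                double max = Double.NEGATIVE_INFINITY;
                for (int i = 0; i < size; i++)
                    for (int j = 0; j < size; j++)
                        max = Math.max(max, input[r * size + i][c * size + j]);
                out[r][c] = max;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] image = new double[28][28];                 // one blank 28×28 grayscale image
        for (int r = 0; r < 28; r++) image[r][14] = 1.0;       // draw a vertical stroke
        double[][] verticalLine = new double[5][5];
        for (int r = 0; r < 5; r++) verticalLine[r][2] = 1.0;  // filter that responds to vertical lines
        double[][] features = convolve(image, verticalLine);   // 24×24 feature map
        double[][] pooled = maxPool(features, 2);              // 12×12 after pooling
        System.out.println(features.length + "x" + features[0].length + " -> "
                + pooled.length + "x" + pooled[0].length);
    }
}
```

For a 28×28 image and a 5×5 filter, main prints 24x24 -> 12x12: the filter responds strongly along the stroke, and pooling halves each dimension while keeping the strongest responses.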

The following code segment creates a convolutional neural network using its builder. More details about all the settings used in the builder are available in the API docs.

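ConvolutionalNetwork neuralNet = ConvolutionalNetwork.builder()
                                   .addInputLayer(imageWidth, imageHeight)               // input layer sized to the images (28×28)
                                   .addConvolutionalLayer(12, 5)                         // convolutional layer: pattern detection
                                   .addMaxPoolingLayer(2, 2)                             // max pooling layer: downsampling
                                   .addOutputLayer(labelsCount, ActivationType.SOFTMAX)  // one output (probability) per digit
                                   .build();                                             // create the network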
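It helps to know the size of the data flowing between layers. The sketch below applies the standard output-size formula (in - k + 2p) / s + 1, assuming that addConvolutionalLayer(12, 5) creates 12 filters of size 5×5 with stride 1 and no padding, and that addMaxPoolingLayer(2, 2) uses a 2×2 window with stride 2. These interpretations are assumptions; check the actual parameter order and padding behavior in the API docs.

```java
public class LayerShapes {

    /** Standard output-size formula for convolution and pooling: (in - k + 2p) / s + 1. */
    static int outputSize(int inputSize, int kernelSize, int padding, int stride) {
        if (stride <= 0) throw new IllegalArgumentException("stride must be positive");
        return (inputSize - kernelSize + 2 * padding) / stride + 1;
    }

    /** Trainable weights in a conv layer: one k×k×inChannels kernel plus one bias per filter. */
    static int convParameters(int kernelSize, int inChannels, int filters) {
        return filters * (kernelSize * kernelSize * inChannels + 1);
    }

    public static void main(String[] args) {
        int input = 28;                              // MNIST images are 28×28 with one channel
        int conv = outputSize(input, 5, 0, 1);       // assumed: 5×5 filters, no padding, stride 1
        int pool = outputSize(conv, 2, 0, 2);        // assumed: 2×2 window, stride 2
        int filters = 12;                            // assumed: 12 filters
        System.out.println("input " + input + "x" + input + "x1");
        System.out.println("conv  " + conv + "x" + conv + "x" + filters
                + " (" + convParameters(5, 1, filters) + " parameters)");
        System.out.println("pool  " + pool + "x" + pool + "x" + filters
                + " -> " + (pool * pool * filters) + " values feed the output layer");
    }
}
```

Under these assumptions, the convolutional layer produces 24×24×12 values from 312 trainable parameters, and pooling reduces that to 12×12×12 (1,728 values) before the output layer.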

The code segment below sets the basic training parameters and starts the training procedure. More details about the various settings of the Backpropagation algorithm used to train convolutional neural networks are available in the API docs.

BackpropagationTrainer trainer = neuralNet.getTrainer(); // get the trainer from the neural network and set its parameters below
trainer.setLearningRate(0.001f) // step size: scales how much the internal parameters are adjusted on each update
       .setMaxError(0.05f)      // stop the training when the error falls below this threshold
       .setMaxEpochs(1000);     // stop the training after this many training iterations (epochs)
trainer.train(trainingSet);     // run the training on the specified training set
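The three settings interact roughly as follows: each epoch moves the parameters against the error gradient by a step scaled by the learning rate, and the loop ends when the error reaches the maximum allowed error or the epoch limit is hit. The toy sketch below shows that loop with plain gradient descent on a one-parameter model (y = w·x, mean squared error). It is a hypothetical illustration, not Deep Netts' BackpropagationTrainer, and its learning rate is chosen for the toy problem rather than copied from the code above.

```java
public class TrainingLoopSketch {

    /** Result of a training run: the learned weight, the final error and the epochs used. */
    record Result(double weight, double error, int epochs) {}

    /** Mean squared error of the model y = w * x on the given data. */
    static double meanSquaredError(double w, double[] xs, double[] ys) {
        double sum = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double diff = w * xs[i] - ys[i];
            sum += diff * diff;
        }
        return sum / xs.length;
    }

    /** Gradient descent that stops on either criterion, mirroring maxError and maxEpochs. */
    static Result train(double[] xs, double[] ys, double learningRate, double maxError, int maxEpochs) {
        double w = 0.0;                                   // initial parameter value
        double error = meanSquaredError(w, xs, ys);
        int epoch = 0;
        while (error > maxError && epoch < maxEpochs) {
            double gradient = 0.0;                        // d(MSE)/dw averaged over the data
            for (int i = 0; i < xs.length; i++) {
                gradient += 2.0 * (w * xs[i] - ys[i]) * xs[i];
            }
            gradient /= xs.length;
            w -= learningRate * gradient;                 // step against the gradient, scaled by the learning rate
            error = meanSquaredError(w, xs, ys);
            epoch++;
        }
        return new Result(w, error, epoch);
    }

    public static void main(String[] args) {
        double[] xs = {1, 2, 3, 4};
        double[] ys = {2, 4, 6, 8};                       // generated by y = 2x
        Result r = train(xs, ys, 0.01, 0.05, 1000);
        System.out.printf("w=%.3f error=%.4f epochs=%d%n", r.weight(), r.error(), r.epochs());
    }
}
```

For this data the error falls below 0.05 after 20 epochs, so the maxError criterion ends training long before the 1000-epoch limit, much like the training log below, which stops at epoch 44.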

Once training starts, it logs information about every training iteration (also called an epoch), as shown below. This information includes the prediction error and accuracy as training progresses.

Downloading and/or unpacking MNIST training set to: D:\DeepNettsProjects\GetStartedWithDeepLearningInJava\mnist - this may take a while ( 44.9 MB )!
Downloaded MNIST data set to mnist
Loading images...
Loaded 10 labels
Loaded 1000 images
Splitting data set: [0.65, 0.35]
Creating neural network architecture...
Training the neural network
Initial Train Error:2.445505
Epoch:1, Time:1893ms, TrainError:1.8640603, TrainErrorChange:-0.5814446, TrainAccuracy: 0.09453033
Epoch:2, Time:1718ms, TrainError:1.1450188, TrainErrorChange:-0.71904147, TrainAccuracy: 0.42761928
Epoch:3, Time:1728ms, TrainError:0.81875384, TrainErrorChange:-0.32626498, TrainAccuracy: 0.6140676
Epoch:4, Time:1820ms, TrainError:0.64733726, TrainErrorChange:-0.17141658, TrainAccuracy: 0.69346887
Epoch:5, Time:1962ms, TrainError:0.5406312, TrainErrorChange:-0.10670608, TrainAccuracy: 0.7378232
Epoch:42, Time:1379ms, TrainError:0.05284841, TrainErrorChange:-0.0019328743, TrainAccuracy: 0.99852943
Epoch:43, Time:1432ms, TrainError:0.050998032, TrainErrorChange:-0.0018503778, TrainAccuracy: 0.99852943
Epoch:44, Time:1423ms, TrainError:0.049276747, TrainErrorChange:-0.0017212853, TrainAccuracy: 1.0

Total Training Time: 131212ms
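The log line "Splitting data set: [0.65, 0.35]" suggests that 65% of the loaded images are used for training and 35% are held back for testing. A minimal sketch of such a split, shuffling a copy of the data with a fixed seed and cutting it at the ratio, is shown here; the helper is hypothetical, and Deep Netts' own splitting (including its rounding) may differ.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class TrainTestSplit {

    /** Shuffles a copy of the data with a fixed seed and cuts it at trainRatio. */
    static <T> List<List<T>> split(List<T> data, double trainRatio, long seed) {
        if (trainRatio < 0.0 || trainRatio > 1.0) {
            throw new IllegalArgumentException("trainRatio must be in [0, 1]");
        }
        List<T> shuffled = new ArrayList<>(data);
        Collections.shuffle(shuffled, new Random(seed));
        int cut = (int) Math.round(shuffled.size() * trainRatio);
        List<List<T>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(shuffled.subList(0, cut)));               // training part
        parts.add(new ArrayList<>(shuffled.subList(cut, shuffled.size())));  // test part
        return parts;
    }

    public static void main(String[] args) {
        List<Integer> ids = new ArrayList<>();
        for (int i = 0; i < 1000; i++) ids.add(i);        // stand-ins for 1000 image examples
        List<List<Integer>> parts = split(ids, 0.65, 123L);
        System.out.println("train=" + parts.get(0).size() + ", test=" + parts.get(1).size());
    }
}
```

With 1000 examples and a ratio of 0.65, main prints train=650, test=350. Shuffling before cutting matters because the images are stored grouped by digit; cutting an unshuffled list would leave some digits out of the training set entirely.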

The graph below shows how the network's output error decreases while its prediction accuracy increases during training.

Also note that Deep Netts provides a deep learning IDE that simplifies building and understanding deep learning models with wizards and visual tools.

Model Testing

Model testing (or evaluation) is performed to check how well the trained model performs on new data, that is, examples it has not seen during training. In Deep Netts, testing is a single call to the test() method, which returns various classification metrics that help you understand the quality of the predictions.

// Test/evaluate the trained network to see how it performs on unseen data (the test set)
EvaluationMetrics em = neuralNet.test(testSet);
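Conceptually, evaluation runs the trained network on every test example, takes the most probable digit as the prediction, and compares it with the true label. The sketch below shows that comparison as a 10×10 confusion matrix plus overall accuracy; the class is a hypothetical illustration, not the implementation of Deep Netts' test() method.

```java
public class ConfusionMatrixSketch {

    /** counts[actual][predicted]: how often each true digit was predicted as each digit. */
    static int[][] confusionMatrix(int[] actual, int[] predicted, int classCount) {
        if (actual.length != predicted.length) {
            throw new IllegalArgumentException("actual and predicted must have the same length");
        }
        int[][] counts = new int[classCount][classCount];
        for (int i = 0; i < actual.length; i++) {
            counts[actual[i]][predicted[i]]++;
        }
        return counts;
    }

    /** Fraction of examples on the diagonal, i.e. predicted correctly. */
    static double accuracy(int[][] counts) {
        int correct = 0, total = 0;
        for (int a = 0; a < counts.length; a++) {
            for (int p = 0; p < counts.length; p++) {
                total += counts[a][p];
                if (a == p) correct += counts[a][p];
            }
        }
        return total == 0 ? 0.0 : (double) correct / total;
    }

    public static void main(String[] args) {
        int[] actual    = {0, 1, 2, 2, 9, 9};
        int[] predicted = {0, 1, 2, 7, 9, 4};             // two mistakes: 2 -> 7 and 9 -> 4
        int[][] counts = confusionMatrix(actual, predicted, 10);
        System.out.println("accuracy = " + accuracy(counts));
    }
}
```

The off-diagonal cells show which digits get confused with each other, which is often more informative than the single accuracy number.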

The testing/evaluation returns the following results:

Classification metrics
Class: Macro Average
Total items: 368
True positive:298.0 Number of examples correctly classified as positive
True negative:0.0 Number of examples correctly classified as negative
False positive:35.0 Number of examples incorrectly classified as positive
False negative:35.0 Number of examples incorrectly classified as negative
Accuracy (ACC): 0.8097826 How often is classifier correct in total (percent of correct classifications)
Precision (PPV): 0.8948949 How often is classifier correct when it gives positive prediction
Recall: 0.8948949 When it is actually positive class, how often does it give positive prediction
F1 Score: 0.8948949 Harmonic average (balance) of precision and recall
False discovery rate (FDR): 0.1051051
Matthews correlation Coefficient (MCC): -0.10510510549197885

The tricky part is understanding the various evaluation metrics, and Deep Netts helps by providing a concise explanation for each one.
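The summary metrics in the evaluation output follow directly from the four counts. This sketch recomputes them from TP = 298, TN = 0, FP = 35 and FN = 35 using the standard formulas; the class is a hypothetical helper, not part of Deep Netts.

```java
public class ClassificationMetrics {

    final double tp, tn, fp, fn;

    ClassificationMetrics(double tp, double tn, double fp, double fn) {
        this.tp = tp; this.tn = tn; this.fp = fp; this.fn = fn;
    }

    double accuracy()  { return (tp + tn) / (tp + tn + fp + fn); }  // share of all items classified correctly
    double precision() { return tp / (tp + fp); }                   // correct share of positive predictions
    double recall()    { return tp / (tp + fn); }                   // share of actual positives that were found
    double falseDiscoveryRate() { return fp / (tp + fp); }          // equals 1 - precision

    /** Harmonic mean of precision and recall (0 when both are 0). */
    double f1() {
        double p = precision(), r = recall();
        return p + r == 0 ? 0.0 : 2 * p * r / (p + r);
    }

    /** Matthews correlation coefficient in [-1, 1]; 0 by convention when the denominator is 0. */
    double mcc() {
        double denom = Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        return denom == 0 ? 0.0 : (tp * tn - fp * fn) / denom;
    }

    public static void main(String[] args) {
        // Counts copied from the macro-average evaluation output.
        ClassificationMetrics m = new ClassificationMetrics(298, 0, 35, 35);
        System.out.printf("ACC=%.7f PPV=%.7f Recall=%.7f F1=%.7f FDR=%.7f MCC=%.7f%n",
                m.accuracy(), m.precision(), m.recall(), m.f1(), m.falseDiscoveryRate(), m.mcc());
    }
}
```

Run with those counts, it reproduces the reported accuracy (298/368 ≈ 0.8098), precision, recall and F1 (298/333 ≈ 0.8949), FDR (35/333 ≈ 0.1051) and MCC (≈ -0.1051). The negative MCC follows directly from the reported true negative count of 0.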

Model Usage

The code below shows how to use a trained model for prediction. First, we load the image file into an ExampleImage object, which provides the input in a format (Tensor) that the neural network's predict method accepts.

The predict method returns a tensor (basically an array) that contains the probabilities for all digits. The digit with the highest probability is most likely the one in the input image.
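The indexOfMax helper used in the prediction code is not shown in the article. A plausible version over a plain float array is sketched here; how to obtain the values of a Deep Netts Tensor as a float array is not covered, so treat that conversion as an assumption to check in the API docs.

```java
public class ArgMax {

    /** Returns the index of the largest value; ties go to the first occurrence. */
    static int indexOfMax(float[] values) {
        if (values.length == 0) throw new IllegalArgumentException("values must not be empty");
        int maxIdx = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[maxIdx]) maxIdx = i;
        }
        return maxIdx;
    }

    public static void main(String[] args) {
        // Hypothetical network output: one probability per digit 0-9.
        float[] predictions = {0.01f, 0.02f, 0.01f, 0.03f, 0.05f, 0.02f, 0.01f, 0.04f, 0.06f, 0.75f};
        int digit = indexOfMax(predictions);
        System.out.println("Most likely digit: " + digit + " (p=" + predictions[digit] + ")");
    }
}
```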

import java.io.File;
import javax.imageio.ImageIO;

ExampleImage someImage = new ExampleImage(ImageIO.read(new File("mnist/training/9/00019.png"))); // load an image from file
someImage.invert(); // invert the image so the network focuses on the dark digit strokes, not the white background
Tensor predictions = neuralNet.predict(someImage.getInput()); // get predictions for the specified image
int maxIdx = indexOfMax(predictions); // index of the highest probability (indexOfMax is a helper not shown in this snippet)
System.out.println("Image label with highest probability: " + neuralNet.getOutputLabel(maxIdx));
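ExampleImage takes care of converting the image into network input. To give a rough idea of what loading and invert() involve, this JDK-only sketch flattens a BufferedImage row by row into grayscale values scaled to [0, 1], optionally inverted as 1 - value. The averaging of RGB channels, the [0, 1] scaling and the class name are assumptions; Deep Netts' ExampleImage may do this differently.

```java
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;

public class ImageToInput {

    /** Flattens an image row by row into grayscale values in [0, 1], optionally inverted. */
    static float[] toInput(BufferedImage image, boolean invert) {
        int w = image.getWidth(), h = image.getHeight();
        float[] input = new float[w * h];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                int rgb = image.getRGB(x, y);
                int r = (rgb >> 16) & 0xFF, g = (rgb >> 8) & 0xFF, b = rgb & 0xFF;
                float gray = (r + g + b) / (3f * 255f);          // 0 = black, 1 = white
                input[y * w + x] = invert ? 1f - gray : gray;    // inverted: dark strokes become high values
            }
        }
        return input;
    }

    public static void main(String[] args) throws IOException {
        BufferedImage image = ImageIO.read(new File("mnist/training/9/00019.png"));
        if (image == null) throw new IOException("No ImageIO reader for this file format");
        float[] input = toInput(image, true);
        System.out.println("Input length: " + input.length);   // 784 for a 28×28 image
    }
}
```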

Also note that a trained model can be saved, so that you can load it later and use it in your application.
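The article does not show the Deep Netts call for saving a model, so the following sketch only illustrates the general save/load round trip with standard Java object serialization on a hypothetical TinyModel class. It is a stand-in, not a Deep Netts network; consult the Deep Netts API docs for the supported way to persist a trained ConvolutionalNetwork.

```java
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Path;

public class ModelPersistence {

    /** Hypothetical stand-in for a trained model: just its weights and output labels. */
    static class TinyModel implements Serializable {
        private static final long serialVersionUID = 1L;
        final float[] weights;
        final String[] labels;

        TinyModel(float[] weights, String[] labels) {
            this.weights = weights;
            this.labels = labels;
        }
    }

    /** Writes the model to a file using Java object serialization. */
    static void save(TinyModel model, Path file) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(Files.newOutputStream(file))) {
            out.writeObject(model);
        }
    }

    /** Reads a model previously written by save(). */
    static TinyModel load(Path file) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(Files.newInputStream(file))) {
            return (TinyModel) in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        Path file = Files.createTempFile("model", ".bin");
        save(new TinyModel(new float[]{0.5f, -1.25f}, new String[]{"0", "1"}), file);
        TinyModel restored = load(file);
        System.out.println("Restored " + restored.weights.length + " weights, labels "
                + String.join(",", restored.labels));
    }
}
```

Only deserialize files from trusted sources, since reading untrusted serialized data is a known security risk in Java.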


Full source of the example

Free Deep Netts download

Step-by-step guide for installing Deep Netts
