This tutorial walks you through building state-of-the-art AI for image recognition with Java, based on deep learning. It will help you get started with modern AI development using only your Java skills.
The example used in this article is handwritten digit recognition, which is commonly used as the "hello world" of deep learning.
The full source code for this project is available on GitHub, and you can use it as a starting point for your own image recognition projects.
To run the example, you need Deep Netts, which is available as a free download here.
A step-by-step guide to setting up Deep Netts is available here.
Our data set consists of 60,000 images of handwritten digits, each 28×28 pixels in size. A few sample images are shown in the image below.
The example images used to train the deep learning model are located in the folder mnist/training. This folder contains ten sub-folders named 0–9, and each sub-folder contains images of one specific digit. For simplicity, this example uses a randomly chosen subset of 1000 images. The mnist folder also contains two files: index.txt, which lists the subset of images to use for training, and labels.txt, which contains the image category labels (in this case the digits 0–9). The images are downloaded and unpacked automatically when you clone and run the example from GitHub.
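Because each digit has its own folder, an image's label can be read directly from the name of its parent folder. The sketch below shows that idea together with picking a repeatable random subset, similar to the 1000 images used here. `MnistLayout`, `labelOf` and `randomSubset` are hypothetical names for illustration only, not part of Deep Netts; in the example project the chosen subset is listed in index.txt.

```java
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helpers illustrating the mnist/training/DIGIT/IMAGE folder layout. */
final class MnistLayout {

    /** The label is the name of the folder holding the image, e.g. "9" for mnist/training/9/00019.png. */
    static String labelOf(Path imagePath) {
        Path folder = imagePath.getParent();
        if (folder == null || folder.getFileName() == null) {
            throw new IllegalArgumentException("Image is not inside a label folder: " + imagePath);
        }
        return folder.getFileName().toString();
    }

    /** Picks size items at random (without replacement); a fixed seed makes the choice repeatable. */
    static <T> List<T> randomSubset(List<T> items, int size, long seed) {
        if (size < 0 || size > items.size()) {
            throw new IllegalArgumentException("Invalid subset size: " + size);
        }
        List<T> shuffled = new ArrayList<>(items);
        Collections.shuffle(shuffled, new Random(seed));
        return new ArrayList<>(shuffled.subList(0, size));
    }

    public static void main(String[] args) {
        System.out.println(labelOf(Paths.get("mnist", "training", "9", "00019.png"))); // prints 9
    }
}
```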
The model training procedure consists of iteratively presenting example images to a deep learning model (in this case a convolutional neural network), which automatically tunes its internal parameters to lower the output error and increase the recognition accuracy. This tuning procedure compares the image labels from the training data with the actual outputs (predictions) of the neural network to calculate the output error, and uses an optimization technique to minimize that error. Each output of the convolutional neural network corresponds to a single digit and represents the probability that the input image shows that digit.
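A minimal sketch of how such an output error can be measured, assuming softmax outputs and cross-entropy loss (the settings chosen in the network builder later in this tutorial). The label acts as a target in which only the correct digit has probability 1, so the cross-entropy reduces to the negative log of the probability the network assigns to the correct digit. This is plain Java for illustration, not Deep Netts' internal code; `OutputError` is a hypothetical name.

```java
/** Illustrative sketch: softmax outputs and cross-entropy error for a digit label. */
final class OutputError {

    /** Turns raw output scores into probabilities that are positive and sum to 1. */
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s); // subtracting the max below avoids overflow in Math.exp
        }
        double[] probs = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            probs[i] = Math.exp(scores[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    /**
     * Cross-entropy against a one-hot target (1 for the correct digit, 0 elsewhere):
     * all terms but one vanish, leaving -log(probability of the correct digit).
     */
    static double crossEntropy(double[] probs, int correctDigit) {
        return -Math.log(probs[correctDigit]);
    }

    public static void main(String[] args) {
        double[] probs = softmax(new double[10]);   // equal scores: every digit gets probability 0.1
        System.out.println(crossEntropy(probs, 7)); // about 2.3026, i.e. ln(10)
    }
}
```

The more probability the network puts on the correct digit, the smaller this error becomes, which is exactly what training tries to achieve.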
Typically, a convolutional neural network consists of a stack of processing blocks called layers. Convolutional layers perform pattern detection, pooling layers downsample their inputs, and fully connected layers perform classification. The number and size of the layers depend on the problem and are usually determined experimentally. The architecture of a small convolutional neural network is shown in the image below. More details about how all this works can be found in this tutorial.
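To make the layer descriptions concrete, here is a minimal plain-Java sketch of what a convolutional layer (a single filter followed by ReLU) and a max pooling layer compute on one 2D input. It assumes no padding and stride 1 for the convolution, and non-overlapping pooling windows. Real convolutional layers apply many filters whose weights are learned during training; `CnnOps` is a hypothetical name.

```java
/** Illustrative single-channel versions of convolution (with ReLU) and max pooling. */
final class CnnOps {

    /**
     * Slides the kernel over the input (no padding, stride 1) and applies ReLU to each sum.
     * Like most CNN libraries, this does not flip the kernel (strictly, cross-correlation).
     */
    static double[][] convolveRelu(double[][] input, double[][] kernel) {
        int outH = input.length - kernel.length + 1;
        int outW = input[0].length - kernel[0].length + 1;
        double[][] out = new double[outH][outW];
        for (int r = 0; r < outH; r++) {
            for (int c = 0; c < outW; c++) {
                double sum = 0.0;
                for (int kr = 0; kr < kernel.length; kr++) {
                    for (int kc = 0; kc < kernel[0].length; kc++) {
                        sum += input[r + kr][c + kc] * kernel[kr][kc];
                    }
                }
                out[r][c] = Math.max(0.0, sum); // ReLU: keep positive responses, zero out the rest
            }
        }
        return out;
    }

    /** Keeps the largest value in each non-overlapping size x size window (stride = size). */
    static double[][] maxPool(double[][] input, int size) {
        int outH = input.length / size;
        int outW = input[0].length / size;
        double[][] out = new double[outH][outW];
        for (int r = 0; r < outH; r++) {
            for (int c = 0; c < outW; c++) {
                double max = Double.NEGATIVE_INFINITY;
                for (int i = 0; i < size; i++) {
                    for (int j = 0; j < size; j++) {
                        max = Math.max(max, input[r * size + i][c * size + j]);
                    }
                }
                out[r][c] = max;
            }
        }
        return out;
    }
}
```

For example, a 2×2 kernel {{-1, 1}, {-1, 1}} responds only where dark pixels on the left meet bright pixels on the right, which is the kind of local pattern a convolutional layer detects; pooling then keeps the strongest response in each window while halving the width and height.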
The following code creates an instance of a convolutional neural network using its builder. More details about all the settings used in the builder are available in the apidocs.
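```java
ConvolutionalNetwork neuralNet = ConvolutionalNetwork.builder()
        .addInputLayer(imageWidth, imageHeight)              // input layer sized to the images
        .addConvolutionalLayer(12, 5)                        // convolutional layer: detects local patterns
        .addMaxPoolingLayer(2, 2)                            // max pooling layer: downsamples feature maps
        .addFullyConnectedLayer(60)                          // fully connected layer with 60 neurons
        .addOutputLayer(labelsCount, ActivationType.SOFTMAX) // one output per digit; softmax gives probabilities
        .hiddenActivationFunction(ActivationType.RELU)       // activation function used in hidden layers
        .lossFunction(LossType.CROSS_ENTROPY)                // error measure minimized during training
        .build();                                            // create the configured network
```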
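A quick way to sanity-check an architecture is to track the spatial size of the feature maps. The standard formula for a convolution or pooling window is (input - filter + 2 × padding) / stride + 1. The sketch below applies it to the 28×28 MNIST input under stated assumptions: that the convolutional layer uses a 5×5 filter with no padding and stride 1, and that max pooling uses a 2×2 window with stride 2. Check the Deep Netts apidocs for how the builder parameters and padding are actually interpreted; `LayerShapes` is a hypothetical helper.

```java
/** Hypothetical helper for tracking feature-map sizes through convolution and pooling layers. */
final class LayerShapes {

    /** Output size along one dimension: (inputSize - filterSize + 2 * padding) / stride + 1. */
    static int outputSize(int inputSize, int filterSize, int padding, int stride) {
        int span = inputSize - filterSize + 2 * padding;
        if (span < 0 || stride <= 0) {
            throw new IllegalArgumentException("Filter larger than padded input, or invalid stride");
        }
        return span / stride + 1; // integer division drops positions where the window would not fit
    }

    public static void main(String[] args) {
        int afterConv = outputSize(28, 5, 0, 1);        // assumption: 5x5 filter, no padding, stride 1
        int afterPool = outputSize(afterConv, 2, 0, 2); // assumption: 2x2 window, stride 2
        System.out.println(afterConv + " -> " + afterPool); // prints "24 -> 12"
    }
}
```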
The code below sets the basic training parameters and starts the training procedure. More details about the settings of the backpropagation algorithm used to train convolutional neural networks are available in the apidocs.
```java
BackpropagationTrainer trainer = neuralNet.getTrainer(); // get the trainer from the neural network
// set the training parameters
trainer.setLearningRate(0.001f) // step size: how strongly each update adjusts the internal parameters
       .setMaxError(0.05f)      // stop training when the error drops to this threshold
       .setMaxEpochs(1000);     // stop training after this many epochs (passes over the training set)
trainer.train(trainingSet);     // run the training on the specified training set
```
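The three parameters are easiest to understand in a toy setting. The sketch below is not Deep Netts' BackpropagationTrainer; it is a one-weight gradient descent loop (fitting y = w·x with mean squared error) that mirrors the same controls: the learning rate scales each update of the internal parameter, training stops once the error falls to maxError, and maxEpochs caps the number of passes over the data. `ToyTrainer` and its members are hypothetical names.

```java
/** Toy one-weight gradient descent illustrating learning rate, max error and max epochs. */
final class ToyTrainer {

    /** Outcome of a training run. */
    static final class Result {
        final double weight;
        final double error;
        final int epochs;

        Result(double weight, double error, int epochs) {
            this.weight = weight;
            this.error = error;
            this.epochs = epochs;
        }
    }

    /** Mean squared error of the model y = w * x over all examples. */
    static double meanSquaredError(double[] xs, double[] ys, double w) {
        double sum = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double diff = w * xs[i] - ys[i];
            sum += diff * diff;
        }
        return sum / xs.length;
    }

    /** Repeats full passes (epochs) over the data until the error or epoch limit is reached. */
    static Result train(double[] xs, double[] ys, double learningRate, double maxError, int maxEpochs) {
        double w = 0.0; // the single internal parameter, starting from an arbitrary value
        double error = meanSquaredError(xs, ys, w);
        int epoch = 0;
        while (error > maxError && epoch < maxEpochs) {
            double gradient = 0.0; // derivative of the mean squared error with respect to w
            for (int i = 0; i < xs.length; i++) {
                gradient += 2 * (w * xs[i] - ys[i]) * xs[i];
            }
            gradient /= xs.length;
            w -= learningRate * gradient; // move against the gradient, scaled by the learning rate
            error = meanSquaredError(xs, ys, w);
            epoch++;
        }
        return new Result(w, error, epoch);
    }

    public static void main(String[] args) {
        Result r = train(new double[] {1, 2, 3}, new double[] {2, 4, 6}, 0.05, 1e-6, 1000);
        System.out.println("w=" + r.weight + ", error=" + r.error + ", epochs=" + r.epochs);
    }
}
```

With a learning rate of 0.05 the weight approaches 2 and training stops on the error threshold long before 1000 epochs. With a much larger learning rate the updates overshoot and the error grows instead, which is one reason learning rates are usually kept small.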
Once training starts, it logs information about every training iteration (also called an epoch), as shown below. This information includes the prediction error and accuracy during the training process.
```
Downloading and/or unpacking MNIST training set to: D:\DeepNettsProjects\GetStartedWithDeepLearningInJava\mnist - this may take a while ( 44.9 MB )!
Downloaded MNIST data set to mnist
Loading images...
Loaded 10 labels
Loaded 1000 images
Splitting data set: [0.65, 0.35]
Creating neural network architecture...
Training the neural network
------------------------------------------------------------------------
TRAINING NEURAL NETWORK
------------------------------------------------------------------------
Initial Train Error:2.445505
Epoch:1, Time:1893ms, TrainError:1.8640603, TrainErrorChange:-0.5814446, TrainAccuracy: 0.09453033
Epoch:2, Time:1718ms, TrainError:1.1450188, TrainErrorChange:-0.71904147, TrainAccuracy: 0.42761928
Epoch:3, Time:1728ms, TrainError:0.81875384, TrainErrorChange:-0.32626498, TrainAccuracy: 0.6140676
Epoch:4, Time:1820ms, TrainError:0.64733726, TrainErrorChange:-0.17141658, TrainAccuracy: 0.69346887
Epoch:5, Time:1962ms, TrainError:0.5406312, TrainErrorChange:-0.10670608, TrainAccuracy: 0.7378232
...
...
Epoch:42, Time:1379ms, TrainError:0.05284841, TrainErrorChange:-0.0019328743, TrainAccuracy: 0.99852943
Epoch:43, Time:1432ms, TrainError:0.050998032, TrainErrorChange:-0.0018503778, TrainAccuracy: 0.99852943
Epoch:44, Time:1423ms, TrainError:0.049276747, TrainErrorChange:-0.0017212853, TrainAccuracy: 1.0
TRAINING COMPLETED
Total Training Time: 131212ms
```
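The log reports "Splitting data set: [0.65, 0.35]", which indicates that the 1000 images are divided into a training part (65%) and a test part (35%) that is kept aside for evaluation. A minimal plain-Java sketch of such a split, assuming the examples are shuffled first with a fixed seed so that both parts are likely to contain every digit and the result is repeatable. `DataSplit.split` is a hypothetical helper, not the Deep Netts method used by the example.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper: shuffle examples and split them into training and test sets. */
final class DataSplit {

    /** Returns [trainingSet, testSet]; trainFraction is e.g. 0.65 for a 65/35 split. */
    static <T> List<List<T>> split(List<T> examples, double trainFraction, long seed) {
        if (trainFraction <= 0 || trainFraction >= 1) {
            throw new IllegalArgumentException("trainFraction must be between 0 and 1");
        }
        List<T> shuffled = new ArrayList<>(examples);
        Collections.shuffle(shuffled, new Random(seed)); // mix the digits so both parts are representative
        int trainSize = (int) Math.round(shuffled.size() * trainFraction);
        List<List<T>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(shuffled.subList(0, trainSize)));               // training set
        parts.add(new ArrayList<>(shuffled.subList(trainSize, shuffled.size()))); // test set, unseen in training
        return parts;
    }
}
```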
The graph below shows how the network's output error decreases while its prediction accuracy increases during training.
Also note that Deep Netts provides a deep learning IDE that simplifies building and understanding deep learning models with wizards and visual tools.
Model testing (or evaluation) is performed to check how well the trained model performs on new data, that is, examples it has not seen during training. In Deep Netts, testing is done with a single call to the test() method, which returns various classification metrics that help you understand the quality of the predictions.
```java
// Test/evaluate the trained network to see how it performs on unseen data: the test set
EvaluationMetrics em = neuralNet.test(testSet);
```
The testing/evaluation returns the following results:
```
Classification metrics
Class: Macro Average
Total items: 368
True positive:298.0    Number of examples correctly classified as positive
True negative:0.0      Number of examples correctly classified as negative
False positive:35.0    Number of examples incorrectly classified as positive
False negative:35.0    Number of examples incorrectly classified as negative
Accuracy (ACC): 0.8097826     How often is classifier correct in total (percent of correct classifications)
Precision (PPV): 0.8948949    How often is classifier correct when it gives positive prediction
Recall: 0.8948949             When it is actually positive class, how often does it give positive prediction
F1 Score: 0.8948949           Harmonic average (balance) of precision and recall
False discovery rate (FDR): 0.1051051
Matthews correlation Coefficient (MCC): -0.10510510549197885
```
The tricky part here is understanding the various evaluation metrics, and Deep Netts helps by providing a concise explanation of each one.
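As a check on what the report means, the macro-averaged values above can be recomputed from the reported counts (TP = 298, FP = 35, FN = 35, TN = 0, 368 items) with the usual formulas. This is a small plain-Java sketch; `Metrics` is a hypothetical class, not the Deep Netts EvaluationMetrics API.

```java
/** Standard classification metrics computed from confusion-matrix counts. */
final class Metrics {

    /** Share of all items that were classified correctly: (TP + TN) / total. */
    static double accuracy(double tp, double tn, double total) {
        return (tp + tn) / total;
    }

    /** Of the positive predictions, the share that were correct: TP / (TP + FP). */
    static double precision(double tp, double fp) {
        return tp / (tp + fp);
    }

    /** Of the actual positives, the share that were found: TP / (TP + FN). */
    static double recall(double tp, double fn) {
        return tp / (tp + fn);
    }

    /** Harmonic mean of precision and recall. */
    static double f1(double precision, double recall) {
        return 2 * precision * recall / (precision + recall);
    }

    /** Of the positive predictions, the share that were wrong: FP / (TP + FP), i.e. 1 - precision. */
    static double falseDiscoveryRate(double tp, double fp) {
        return fp / (tp + fp);
    }

    public static void main(String[] args) {
        double p = precision(298, 35);
        double r = recall(298, 35);
        System.out.printf("ACC=%.7f PPV=%.7f REC=%.7f F1=%.7f FDR=%.7f%n",
                accuracy(298, 0, 368), p, r, f1(p, r), falseDiscoveryRate(298, 35));
    }
}
```

With these counts, accuracy is 298 / 368 ≈ 0.8097826 and precision, recall and F1 are all 298 / 333 ≈ 0.8948949, matching the report; precision equals recall here because FP and FN are both 35.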
The code below shows how to use a trained model for prediction. First, we load the image file into an ExampleImage object, which provides the input in a format (Tensor) that the neural network's predict method accepts.
The predict method returns a tensor (basically an array) containing the probabilities for all digits. The digit with the highest probability is most likely the one shown in the input image.
```java
ExampleImage someImage = new ExampleImage(ImageIO.read(new File("mnist/training/9/00019.png"))); // load an image from file
someImage.invert(); // needed for this data set: focus on the dark digit rather than the white background
Tensor predictions = neuralNet.predict(someImage.getInput()); // get predictions for the specified image
int maxIdx = indexOfMax(predictions); // index of the prediction with the highest probability
LOGGER.info(predictions);
LOGGER.info("Image label with highest probability:" + neuralNet.getOutputLabel(maxIdx));
```
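The snippet relies on ExampleImage for preprocessing and on an indexOfMax helper that is not shown here. To illustrate what those steps amount to, here is a plain-Java sketch that converts a BufferedImage to grayscale values in [0, 1], optionally inverted, and returns the index of the largest value. `PredictionHelpers` and its methods are hypothetical; they work on a float[] rather than a Deep Netts Tensor, and the preprocessing ExampleImage actually performs may differ.

```java
import java.awt.image.BufferedImage;

/** Hypothetical plain-Java versions of image preprocessing and argmax. */
final class PredictionHelpers {

    /** Grayscale values in [0, 1], row by row; with invert, dark pixels become high values. */
    static float[] toInput(BufferedImage image, boolean invert) {
        int width = image.getWidth();
        int height = image.getHeight();
        float[] input = new float[width * height];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int rgb = image.getRGB(x, y);
                int r = (rgb >> 16) & 0xFF;
                int g = (rgb >> 8) & 0xFF;
                int b = rgb & 0xFF;
                float gray = (r + g + b) / (3f * 255f); // average of the color channels, scaled to [0, 1]
                input[y * width + x] = invert ? 1f - gray : gray;
            }
        }
        return input;
    }

    /** Index of the largest value (the first one on ties), e.g. the most probable digit. */
    static int indexOfMax(float[] values) {
        if (values.length == 0) {
            throw new IllegalArgumentException("values must not be empty");
        }
        int maxIdx = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[maxIdx]) {
                maxIdx = i;
            }
        }
        return maxIdx;
    }
}
```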
Also note that a trained model can be saved, so that you can load it later and use it in your application.
Full source of the example
Free Deep Netts download
Step-by-step guide for installing Deep Netts