
Machine Learning Tutorial for Java Developers

What is Machine Learning?

At a high level, machine learning is a computer algorithm that can extract hidden rules from given data. This special type of algorithm is able to automatically adjust its own internal parameters using sample data, so that it can estimate or predict something useful for similar data. The procedure of automatically adjusting internal parameters using sample data is called learning, training or model building. Once the algorithm has finished training, we get a so-called trained model: a model that can be used to predict something for new, similar data. The sample data used for learning/training is called the training set.
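To make the idea of "adjusting internal parameters using sample data" a bit more concrete, here is a minimal plain-Java sketch (not Deep Netts code) of a model with a single internal parameter w, which predicts y = w * x. Training uses gradient descent to nudge w so that the mean squared error on a small training set shrinks. The class name LearnWeight, its methods, the learning rate and the made-up data are illustrative assumptions.

```java
/** Toy illustration: "learning" the single internal parameter w of the model y = w * x. */
public class LearnWeight {

    /** Adjusts w with gradient descent to minimize the mean squared error on the training set. */
    static double train(double[] xs, double[] ys, double learningRate, int epochs) {
        double w = 0.0; // initial, untrained value of the internal parameter
        for (int epoch = 0; epoch < epochs; epoch++) {
            double gradient = 0.0;
            for (int i = 0; i < xs.length; i++) {
                double error = w * xs[i] - ys[i];            // prediction minus expected value
                gradient += 2 * error * xs[i] / xs.length;    // derivative of the MSE with respect to w
            }
            w -= learningRate * gradient; // step against the gradient to reduce the error
        }
        return w;
    }

    /** The trained model: uses the learned w to predict a value for new input. */
    static double predict(double w, double x) {
        return w * x;
    }

    public static void main(String[] args) {
        // Hypothetical training set whose hidden rule is y = 2x
        double[] xs = {1, 2, 3, 4};
        double[] ys = {2, 4, 6, 8};
        double w = train(xs, ys, 0.01, 1000);
        System.out.printf("learned w = %.4f, prediction for x = 5: %.4f%n", w, predict(w, 5));
    }
}
```

Running main should print a learned w very close to 2, the hidden rule in the made-up data, and a prediction close to 10 for x = 5.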

So what can it learn? The most commonly used type of machine learning algorithms, called supervised learning, can learn to predict a label or a numeric value for a given input. You can think of it as learning from examples, where examples of correct predictions are given as a set of (input, prediction) pairs within the training set. This fundamentally changes the way we use computers: instead of telling them how to do something, we show them what to do by giving them examples.
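Learning from (input, prediction) pairs can be illustrated with one of the simplest supervised techniques, a 1-nearest-neighbour classifier: it keeps the training pairs and, for a new input, returns the label of the most similar stored input. This is a different and much simpler technique than the neural network used later in this guide; the class name, the feature names and the data below are hypothetical.

```java
import java.util.List;

/** Toy illustration of supervised learning from (input, label) example pairs. */
public class LearnFromExamples {

    /** One training example: an input vector paired with its expected label. */
    static final class Example {
        final double[] input;
        final String label;

        Example(double[] input, String label) {
            this.input = input;
            this.label = label;
        }
    }

    /** 1-nearest-neighbour: returns the label of the training example closest to the input. */
    static String predict(List<Example> trainingSet, double[] input) {
        String bestLabel = null;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (Example example : trainingSet) {
            double distance = 0.0;
            for (int i = 0; i < input.length; i++) {
                double diff = example.input[i] - input[i];
                distance += diff * diff; // squared Euclidean distance
            }
            if (distance < bestDistance) {
                bestDistance = distance;
                bestLabel = example.label;
            }
        }
        return bestLabel;
    }

    public static void main(String[] args) {
        // Hypothetical (input, label) pairs: [minutes on site, ads viewed] -> clicked or not
        List<Example> trainingSet = List.of(
                new Example(new double[] {1, 1}, "no click"),
                new Example(new double[] {2, 1}, "no click"),
                new Example(new double[] {9, 6}, "click"),
                new Example(new double[] {8, 7}, "click"));
        System.out.println(predict(trainingSet, new double[] {7, 6})); // a new, unseen input
    }
}
```

For the new input {7, 6} the closest stored example is {8, 7}, so the program prints "click".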

What can machine learning do?

The specific predictive tasks performed by supervised machine learning models that are most often used in practice include:

  • Classification – when you want to predict a category (a class label) for a given input. For example: predict whether a user is going to click on an ad or not. It answers the question of which category something belongs to, based on a given set of attributes. If the desired prediction is yes/no or some other qualitative value, then it is a classification task.
  • Regression – when you want to predict a continuous numeric value for a given input. For example: predict sales for a given marketing budget. It answers the question of what the value of something will be, given the value of something related. So if the desired prediction is a numeric value, then it is a regression task.
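For the regression example of predicting sales for a given marketing budget, a minimal sketch is a straight line sales = a + b * budget fitted by ordinary least squares. The prediction is a continuous number, which is what makes this a regression task. The class SalesRegression and its data are hypothetical, chosen only to illustrate the idea.

```java
/** Toy regression: fits sales = a + b * budget by ordinary least squares. */
public class SalesRegression {

    /** Returns {a, b}, the intercept and slope that minimize the squared error. */
    static double[] fit(double[] budget, double[] sales) {
        int n = budget.length;
        double meanX = 0, meanY = 0;
        for (int i = 0; i < n; i++) {
            meanX += budget[i] / n;
            meanY += sales[i] / n;
        }
        double covXY = 0, varX = 0;
        for (int i = 0; i < n; i++) {
            covXY += (budget[i] - meanX) * (sales[i] - meanY);
            varX += (budget[i] - meanX) * (budget[i] - meanX);
        }
        double b = covXY / varX;       // slope
        double a = meanY - b * meanX;  // intercept
        return new double[] {a, b};
    }

    /** The prediction is a continuous number, which is what makes this a regression model. */
    static double predict(double[] model, double budget) {
        return model[0] + model[1] * budget;
    }

    public static void main(String[] args) {
        // Hypothetical data: marketing budget and sales, both in thousands of dollars
        double[] budget = {10, 20, 30, 40};
        double[] sales = {25, 45, 65, 85};
        double[] model = fit(budget, sales);
        System.out.printf("predicted sales for budget 50: %.1f%n", predict(model, 50));
    }
}
```

The made-up data follows sales = 5 + 2 * budget exactly, so the fit recovers a = 5 and b = 2, and the predicted sales for a budget of 50 is 105.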

Note that there are also other types of tasks that are not included in this quick introductory overview, but the general principles described here are the same.

Example Java Code for Machine Learning

Here is example code that shows how to train a neural network for a classification task in Java, using the feed forward neural network from the Deep Netts library. This type of machine learning model can be used to roughly predict whether a given email is spam or not. More details about basic machine learning algorithms and feed forward networks are explained in the following post. For now, take a look at the code for a minimal example that loads data from a CSV file and trains a machine learning model. Full source code for this example is available on GitHub.

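import deepnetts.data.DataSets;
import deepnetts.net.FeedForwardNetwork;
import deepnetts.net.layers.activation.ActivationType;
import javax.visrec.ml.data.DataSet;

// Load data from CSV file: inputsNum input columns followed by outputsNum output columns
DataSet trainingSet = DataSets.readCsv("spam.csv", inputsNum, outputsNum);

// Create a feed forward neural network using builder
FeedForwardNetwork neuralNet = FeedForwardNetwork.builder()
                               .addInputLayer(inputsNum)
                               .addFullyConnectedLayer(15, ActivationType.RELU)
                               .addOutputLayer(outputsNum, ActivationType.SIGMOID)
                               .build();

// Train neural network on the training set
neuralNet.train(trainingSet);

// Use the trained model (neural network) for prediction: once the features of a
// new email are set as the network input, getOutput() returns the network output
float[] prediction = neuralNet.getOutput();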

Now that you’re familiar with the basic concepts, take a look at the example code above to see how the Deep Netts API provides an easy and intuitive way to use machine learning in Java:

  • DataSet: class DataSet holds the data used to train a machine learning algorithm, i.e. the training set.
  • DataSets.readCsv: class DataSets provides utility methods for working with data sets. One of them, readCsv, loads data from a CSV file and returns an instance of DataSet that is used as the training set.
  • FeedForwardNetwork.builder(): class FeedForwardNetwork provides a widely used type of machine learning model that can be used for both classification and regression problems. Its static builder() method returns a builder that is used to specify the various settings of a neural network, such as its layers and their activation functions.
  • train: FeedForwardNetwork also provides the train method, which runs the training (learning) procedure on the feed forward network using the given trainingSet parameter.
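The builder calls in the Deep Netts example describe a network with one fully connected hidden layer using ReLU activation and an output layer using sigmoid activation. As a rough, library-independent illustration of what such a network computes for a single input, here is a plain-Java forward pass. It is not Deep Netts code: the layer sizes are shrunk to keep it readable, and the class name ForwardPass and all weights are hypothetical (in a real network, training is what sets them).

```java
import java.util.function.DoubleUnaryOperator;

/** Sketch of a forward pass: a ReLU hidden layer followed by a sigmoid output layer. */
public class ForwardPass {

    static final DoubleUnaryOperator RELU = x -> Math.max(0.0, x);
    static final DoubleUnaryOperator SIGMOID = x -> 1.0 / (1.0 + Math.exp(-x));

    /** Fully connected layer: out[j] = activation(bias[j] + sum over i of weights[j][i] * in[i]). */
    static double[] dense(double[] in, double[][] weights, double[] bias,
                          DoubleUnaryOperator activation) {
        double[] out = new double[weights.length];
        for (int j = 0; j < weights.length; j++) {
            double sum = bias[j];
            for (int i = 0; i < in.length; i++) {
                sum += weights[j][i] * in[i];
            }
            out[j] = activation.applyAsDouble(sum);
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical weights: 2 inputs -> 3 hidden neurons (ReLU) -> 1 output (sigmoid).
        // In a real network, these are the internal parameters that training adjusts.
        double[][] hiddenWeights = {{0.5, -0.2}, {0.3, 0.8}, {-0.6, 0.1}};
        double[] hiddenBias = {0.0, 0.1, 0.0};
        double[][] outputWeights = {{1.0, -0.5, 0.7}};
        double[] outputBias = {-0.2};

        double[] input = {1.0, 2.0};
        double[] hidden = dense(input, hiddenWeights, hiddenBias, RELU);
        double[] output = dense(hidden, outputWeights, outputBias, SIGMOID);
        // The sigmoid squashes the output into (0, 1), e.g. an estimated spam probability
        System.out.println("network output: " + output[0]);
    }
}
```

With these made-up weights the output is about 0.25; with a common 0.5 threshold, that would be read as "not spam".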


Download Deep Netts for free to get started with machine learning in Java faster and more easily.