Machine Learning Tutorial for Java Developers

What is Machine Learning?

A machine learning algorithm is a computer algorithm that adjusts its own internal parameters using sample data, in order to perform a predictive task (guessing something useful) on similar data.

The procedure of automatically adjusting internal parameters using sample data is called learning, training, or model building. Once the algorithm has finished training, we get a so-called trained model. This model can then be used to perform predictive tasks on new, similar data.
The sample data used for learning/training is called the training set.
So what can it learn? One commonly used type of machine learning, called supervised learning, can learn the relationship between the inputs and outputs specified in the training data. You can also think of it as a method for roughly approximating an unknown function from the given inputs and outputs, in order to use it for prediction. This fundamentally changes the way we use computers: instead of telling them how to do something, we show them what to do [1].
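
To make the idea of adjusting internal parameters from sample data concrete, here is a minimal plain Java sketch, not taken from any library. It assumes the unknown function is a straight line and uses simple gradient descent; the class name TinyLinearLearner, the sample data, and the learning rate are all illustrative assumptions.

```java
import java.util.Locale;

/** Minimal sketch: "learning" y = w*x + b from sample input/output pairs. */
final class TinyLinearLearner {
    private double w = 0.0; // internal parameter: slope
    private double b = 0.0; // internal parameter: intercept

    /** Training: repeatedly nudge w and b to reduce the mean squared error on the samples. */
    public void train(double[] inputs, double[] outputs, int epochs, double learningRate) {
        int n = inputs.length;
        for (int epoch = 0; epoch < epochs; epoch++) {
            double gradW = 0.0;
            double gradB = 0.0;
            for (int i = 0; i < n; i++) {
                double error = predict(inputs[i]) - outputs[i];
                gradW += 2 * error * inputs[i] / n; // derivative of the error w.r.t. w
                gradB += 2 * error / n;             // derivative of the error w.r.t. b
            }
            w -= learningRate * gradW;
            b -= learningRate * gradB;
        }
    }

    /** Prediction: apply the learned parameters to a new input. */
    public double predict(double x) {
        return w * x + b;
    }

    public static void main(String[] args) {
        // Training set produced by the "unknown" function y = 2x + 1
        double[] x = {0, 1, 2, 3, 4};
        double[] y = {1, 3, 5, 7, 9};
        TinyLinearLearner model = new TinyLinearLearner();
        model.train(x, y, 5000, 0.05);
        System.out.printf(Locale.ROOT, "prediction for x=5: %.2f%n", model.predict(5));
    }
}
```

The model is never told the rule y = 2x + 1; it only sees the input/output pairs and ends up with parameters that approximate it.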

Machine Learning as a black box

What can it do?

The predictive tasks most often performed by machine learning models in practice include:

  • Classification – tries to predict the category (class) of some input item. For example: predict whether an email is spam or not spam based on its subject and message text, or whether a user is going to click on an ad. It answers the question of which category something belongs to, based on a given set of attributes.
  • Regression – tries to predict some continuous numeric value based on a set of inputs. For example: predict sales for a given marketing budget. It answers the question of what the value of something will be, given the value of something related.
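
As a rough illustration of the classification task, the plain Java sketch below labels a new item with the class of its closest training example (a nearest neighbor rule). The class name, the two numeric attributes, and the sample values are made up for illustration. A regression model would return a number where this one returns a category.

```java
/** Minimal sketch of classification: pick the label of the closest training example. */
final class NearestNeighborClassifier {
    private final double[][] features; // one row of attributes per training example
    private final String[] labels;     // known category of each training example

    public NearestNeighborClassifier(double[][] features, String[] labels) {
        this.features = features;
        this.labels = labels;
    }

    /** Returns the category of the training example closest to the given item. */
    public String classify(double[] item) {
        int best = 0;
        double bestDistance = Double.MAX_VALUE;
        for (int i = 0; i < features.length; i++) {
            double distance = 0.0;
            for (int j = 0; j < item.length; j++) {
                double diff = features[i][j] - item[j];
                distance += diff * diff; // squared Euclidean distance
            }
            if (distance < bestDistance) {
                bestDistance = distance;
                best = i;
            }
        }
        return labels[best];
    }

    public static void main(String[] args) {
        // Hypothetical attributes per email: {number of links, number of exclamation marks}
        double[][] emails = {{8, 5}, {6, 7}, {0, 1}, {1, 0}};
        String[] categories = {"spam", "spam", "not spam", "not spam"};
        NearestNeighborClassifier classifier = new NearestNeighborClassifier(emails, categories);
        System.out.println(classifier.classify(new double[] {7, 4})); // prints "spam"
        System.out.println(classifier.classify(new double[] {1, 1})); // prints "not spam"
    }
}
```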

Note that there are also other types of tasks which are not included in this quick introductory overview, but the general principles described here are the same.

Example Code

Here is example code that trains a very simple neural network for regression, using a feed forward neural network from the Deep Netts deep learning library. This type of machine learning model can be used to roughly predict [example use case]. More details about linear regression and feed forward networks are explained in the following post. For now, take a look at the code for the simplest example of loading data from a CSV file and training a machine learning model.

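// Load data from CSV file (inputsNum and outputsNum are the numbers of input and output columns)
DataSet trainingSet = DataSets.readCsv("fileName.csv", inputsNum, outputsNum);
// Create a feed forward neural network using builder
FeedForwardNetwork neuralNet = FeedForwardNetwork.builder()
                               .addInputLayer(inputsNum)
                               .addOutputLayer(outputsNum, ActivationType.LINEAR)
                               .build();
// Train network
neuralNet.train(trainingSet);
// Use the trained model (neural network) for prediction.
// testInput is a placeholder float[] with inputsNum values; setInput is assumed to accept it.
neuralNet.setInput(testInput);
float[] prediction = neuralNet.getOutput();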

Now that you’re familiar with the basic concepts, the Deep Netts API provides a very easy and intuitive way to use machine learning and deep learning in Java:

  • Class DataSet holds the data that is used for training a machine learning algorithm, that is, the training set.
  • Class DataSets provides various static utility methods for working with data. One of them is readCsv, which loads data from a CSV file and returns an instance of DataSet that is used as the training set.
  • Class FeedForwardNetwork provides a widely used type of machine learning technique that can be used for both classification and regression problems. It provides the train method, which performs the training/learning procedure on the feed forward network using the given trainingSet parameter.
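
To give a feel for what a training set loaded from CSV contains, here is a plain Java sketch that splits each row into inputs and outputs. It is not the Deep Netts implementation: the class SimpleTrainingSet and the method fromCsvLines are hypothetical, and the sketch assumes that the first inputsNum columns of each row are inputs and the following outputsNum columns are outputs.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical stand-in for a training set: input rows paired with output rows. */
final class SimpleTrainingSet {
    final List<float[]> inputs = new ArrayList<>();
    final List<float[]> outputs = new ArrayList<>();

    /** Splits every CSV line into inputsNum input values followed by outputsNum output values. */
    static SimpleTrainingSet fromCsvLines(List<String> lines, int inputsNum, int outputsNum) {
        SimpleTrainingSet set = new SimpleTrainingSet();
        for (String line : lines) {
            if (line.trim().isEmpty()) {
                continue; // skip blank lines
            }
            String[] columns = line.split(",");
            if (columns.length != inputsNum + outputsNum) {
                throw new IllegalArgumentException("Unexpected number of columns: " + line);
            }
            float[] in = new float[inputsNum];
            float[] out = new float[outputsNum];
            for (int i = 0; i < inputsNum; i++) {
                in[i] = Float.parseFloat(columns[i].trim());
            }
            for (int i = 0; i < outputsNum; i++) {
                out[i] = Float.parseFloat(columns[inputsNum + i].trim());
            }
            set.inputs.add(in);
            set.outputs.add(out);
        }
        return set;
    }

    public static void main(String[] args) {
        // Made-up rows: two input columns followed by one output column
        List<String> rows = Arrays.asList("10,200,1500", "20,300,2600");
        SimpleTrainingSet set = fromCsvLines(rows, 2, 1);
        System.out.println(set.inputs.size() + " training examples loaded");
    }
}
```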

Note: Some links are still under construction.


Introduction to Deep Learning: from linear regression to convolutional networks

Key concepts: Machine Learning, Classification, Regression, Model, Training Set, FeedForwardNetwork

Additional content

Slides from Deep Java Dev Meetup on this topic.

[1] Geoffrey Hinton, in an interview with Andrew Ng