Package deepnetts.net.train
Class BackpropagationTrainer
java.lang.Object
deepnetts.net.train.BackpropagationTrainer
- All Implemented Interfaces:
Trainer, Serializable
Backpropagation training algorithm for feed forward and convolutional neural networks.
Backpropagation is a supervised machine learning algorithm that iteratively
reduces prediction error by searching for a minimum of the loss function.
-
Field Summary
Fields
Modifier and Type Field - Description
static final String PROP_BATCH_MODE - Name of the batchMode property
static final String PROP_BATCH_SIZE - Name of the batchSize property
static final String PROP_LEARNING_RATE - Name of the learningRate property
static final String PROP_MAX_EPOCHS - Name of the maxEpochs property
static final String PROP_MAX_ERROR - Name of the maxError property
static final String PROP_MOMENTUM - Name of the momentum property
static final String PROP_OPTIMIZER_TYPE - Name of the optimizer property
-
Constructor Summary
Constructors
Constructor - Description
BackpropagationTrainer(NeuralNetwork neuralNet) - Creates an instance of BackpropagationTrainer for the given neural network to train.
BackpropagationTrainer(Properties prop) - Creates an instance of BackpropagationTrainer with the given properties.
-
Method Summary
Modifier and Type Method - Description
void addListener(TrainingListener listener) - Adds a training listener to this trainer.
boolean createsTrainingSnaphots() - Returns true if the network creates training snapshots, false otherwise.
int getBatchSize() - Batch size is the number of training examples after which the network's weights are adjusted.
int getCheckpointEpochs() - How often, in epochs, snapshots of the trained network should be created.
int getCurrentEpoch() - Returns the current training epoch (iteration) of this trainer.
float getDropout() - Dropout is a technique to prevent overfitting, which skips adjusting weights for some neurons with a given probability.
boolean getEarlyStopping() - Early stopping stops the training if it starts converging slowly, which helps prevent overfitting.
float getEarlyStoppingMinLossChange() - Early stopping stops the training if the error/loss starts converging too slowly.
int getEarlyStoppingPatience() - How many epochs to wait to see if the loss is decreasing too slowly.
boolean getExtendedLogging() - Extended logging includes additional info for debugging the training.
float getLearningRate() - The learning rate controls the step size, as a fraction of the error, used for adjusting the internal parameters (weights) of the neural network.
long getMaxEpochs() - Returns the setting for the maximum number of training epochs (iterations).
float getMaxError() - Returns the setting for the stopping error threshold.
float getMomentum() - The momentum setting helps to avoid oscillations in weight changes and gives more stable and faster training.
NeuralNetwork getNeuralNetwork() - Returns a neural network trained by this trainer.
OptimizerType getOptimizer()
boolean getShuffle() - Returns the shuffle flag, which determines if the training set should be shuffled before each epoch.
int getSnapshotEpochs() - How often, in epochs, to make training snapshots.
String getSnapshotPath() - Path to use for making snapshots, that is, saving the current state of the trained network during the training in order to be able to restore it from a training point if needed.
float getStopAccuracy()
float getStopError() - Alias for getMaxError().
javax.visrec.ml.data.DataSet<?> getTestSet() - The test set is used after the training to estimate the performance of the trained model and its ability to generalize to new data.
float getTrainingAccuracy() - Accuracy metric which gives the percentage of correct predictions for the training set.
float getTrainingLoss() - Total training error/loss at the current epoch.
float getValidationAccuracy() - Accuracy metric which gives the percentage of correct predictions for the validation set.
float getValidationLoss() - Validation loss is an error calculated using the validation set, used to prevent overfitting and to validate the architecture and training settings.
boolean isBatchMode() - In batch mode weights are adjusted after a pass over all examples from the training set, while in online mode weights are adjusted after each training example.
void removeListener(TrainingListener listener) - Removes a training listener from this trainer.
BackpropagationTrainer setBatchMode(boolean batchMode) - Sets the flag that determines whether to use batch mode during the training.
BackpropagationTrainer setBatchSize(int batchSize) - Batch size is the number of training examples after which the network's weights are adjusted.
BackpropagationTrainer setCheckpointEpochs(int checkpointEpochs) - How often, in epochs, snapshots of the trained network should be created.
BackpropagationTrainer setDropout(float dropout) - Dropout is a technique to prevent overfitting, which skips adjusting weights for some neurons with a given probability.
BackpropagationTrainer setEarlyStopping(boolean earlyStopping) - Early stopping stops the training if it starts converging slowly, which helps prevent overfitting.
BackpropagationTrainer setEarlyStoppingMinLossChange(float earlyStoppingMinLossChange) - Early stopping stops the training if the error/loss starts converging too slowly.
BackpropagationTrainer setEarlyStoppingPatience(int earlyStoppingPatience) - How many epochs to wait to see if the loss is decreasing too slowly.
void setExtendedLogging(boolean extendedLogging) - Extended logging includes additional info for debugging the training.
BackpropagationTrainer setL1Regularization(float regL1) - L1 regularization (sum of absolute values) is used to prevent overfitting and overly large weights.
BackpropagationTrainer setL2Regularization(float regL2) - L2 regularization (sum of squares) is used to prevent overfitting and overly large weights.
BackpropagationTrainer setLearningRate(float learningRate) - The learning rate controls the step size, as a fraction of the error, used for adjusting the internal parameters (weights) of the neural network.
BackpropagationTrainer setLearningRateDecay(float decayRate) - Learning rate decay lowers the learning rate with each epoch by the decayRate factor, which may help lower the error.
BackpropagationTrainer setMaxEpochs(long maxEpochs) - Deprecated. Use setStopEpochs instead.
BackpropagationTrainer setMaxError(float maxError) - Deprecated. Use setStopError instead.
BackpropagationTrainer setMomentum(float momentum) - The momentum setting helps to avoid oscillations in weight changes and gives more stable and faster training.
BackpropagationTrainer setOptimizer(OptimizerType optimizer)
final void setProperties(Properties prop) - Sets properties from the available keys in the specified prop object.
BackpropagationTrainer setShuffle(boolean shuffle) - Sets the shuffle flag, which determines if the training set should be shuffled before each epoch.
void setSnapshotEpochs(int snapshotEpochs) - How often, in epochs, to make training snapshots.
BackpropagationTrainer setSnapshotPath(String snapshotPath) - Path to use for making snapshots, that is, saving the current state of the trained network during the training in order to be able to restore it from a training point.
BackpropagationTrainer setStopAccuracy(float stopAccuracy)
BackpropagationTrainer setStopEpochs(long stopEpochs) - Sets the number of epochs/iterations to run the training.
BackpropagationTrainer setStopError(float stopError) - The training stops if the training error reaches this value.
void setTestSet(javax.visrec.ml.data.DataSet<MLDataItem> testSet) - The test set is used after the training to estimate the performance of the trained model and its ability to generalize to new data.
void setTrainingSnapshots(boolean trainingSnapshots) - Training snapshots save the current state of the trained neural network during the training in order to be able to restore it from a training point if needed.
void stop() - Stops the training.
void train(javax.visrec.ml.data.DataSet<?> trainingSet, double valSplit) - Runs training using the given training set, and splits off part of it to use as a validation set.
void train(javax.visrec.ml.data.DataSet<? extends MLDataItem> trainingSet) - Runs training using the specified training set.
void train(javax.visrec.ml.data.DataSet<MLDataItem> trainingSet, javax.visrec.ml.data.DataSet<MLDataItem> validationSet) - Runs training using the given training and validation sets.
void updateLearningRate(float learningRate) - Updates the learning rate for all layers during the learning rate decay.
-
Field Details
-
PROP_MAX_ERROR
-
PROP_MAX_EPOCHS
-
PROP_LEARNING_RATE
-
PROP_MOMENTUM
-
PROP_BATCH_MODE
-
PROP_BATCH_SIZE
-
PROP_OPTIMIZER_TYPE
-
-
Constructor Details
-
BackpropagationTrainer
public BackpropagationTrainer(NeuralNetwork neuralNet)
Creates an instance of BackpropagationTrainer for the given neural network to train.
- Parameters:
neuralNet - neural network to train using this instance of the backpropagation algorithm
-
BackpropagationTrainer
public BackpropagationTrainer(Properties prop)
Creates an instance of BackpropagationTrainer with the given properties.
- Parameters:
prop - key, value pairs of properties for backpropagation
-
-
Method Details
-
train
public void train(javax.visrec.ml.data.DataSet<MLDataItem> trainingSet, javax.visrec.ml.data.DataSet<MLDataItem> validationSet)
Runs training using the given training and validation sets. The training set is used to train the model, while the validation set is used to check model evaluation metrics on unseen data during the training, in order to prevent overfitting. Note that the validation set is different from the test set, which is used after the training to evaluate the trained model.
- Parameters:
trainingSet - set of example data to train the network
validationSet - set of example data to validate the network during the training
-
train
public void train(javax.visrec.ml.data.DataSet<?> trainingSet, double valSplit)
Runs training using the given training set, and splits off part of it to use as a validation set.
- Parameters:
trainingSet - set of example data to train the network
valSplit - fraction of the training set to use as a validation set, a value between 0 and 1, commonly something like 0.1 or 0.2
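To make the meaning of valSplit concrete, the following is a minimal sketch of a fractional split over a plain list. It is not the library's implementation: the class ValidationSplitSketch and its split method are hypothetical, and the sketch assumes the validation part is simply taken from the tail of the list without shuffling.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, not part of Deep Netts: illustrates what a valSplit fraction means.
final class ValidationSplitSketch {

    /** Returns a two-element list: index 0 is the training part, index 1 the validation part. */
    static <T> List<List<T>> split(List<T> examples, double valSplit) {
        if (valSplit <= 0.0 || valSplit >= 1.0) {
            throw new IllegalArgumentException("valSplit must be between 0 and 1");
        }
        // Number of examples reserved for validation, rounded to the nearest integer.
        int valCount = (int) Math.round(examples.size() * valSplit);
        int trainCount = examples.size() - valCount;

        List<List<T>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(examples.subList(0, trainCount)));               // training part
        parts.add(new ArrayList<>(examples.subList(trainCount, examples.size()))); // validation part
        return parts;
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            data.add(i);
        }
        List<List<Integer>> parts = split(data, 0.2);
        System.out.println("training=" + parts.get(0) + " validation=" + parts.get(1));
    }
}
```

With 10 examples and valSplit of 0.2, this sketch keeps 8 examples for training and 2 for validation.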
-
train
public void train(javax.visrec.ml.data.DataSet<? extends MLDataItem> trainingSet)
Runs training using the specified training set. Training is an iterative procedure during which the network's internal parameters (weights) are adjusted in order to minimize prediction error for the given example data in the training set.
-
getMaxEpochs
public long getMaxEpochs()
Returns the setting for the maximum number of training epochs (iterations). Training stops when the specified number of training epochs or the error threshold (stopError) is reached.
- Returns:
max training epochs
-
setMaxEpochs
public BackpropagationTrainer setMaxEpochs(long maxEpochs)
Deprecated. Use setStopEpochs instead.
Sets the maximum number of training epochs (iterations) for training the network. An epoch is a single pass over all training examples from the training set. The training will stop after the specified number of epochs, if the network does not reach some other stopping condition first (like the error threshold).
- Parameters:
maxEpochs - the maximum number of training epochs (iterations) for training the network
- Returns:
this trainer
-
setStopEpochs
public BackpropagationTrainer setStopEpochs(long stopEpochs)
Sets the number of epochs/iterations to run the training. When this number of epochs is reached the training will stop, if the target accuracy has not been reached before.
- Parameters:
stopEpochs - number of epochs after which training will stop
- Returns:
this trainer
-
getMaxError
public float getMaxError()
Returns the setting for the stopping error threshold. The training stops when the total network error reaches this value.
- Returns:
stop error threshold
-
getStopError
public float getStopError()
Alias for getMaxError().
-
setMaxError
public BackpropagationTrainer setMaxError(float maxError)
Deprecated. Use setStopError instead, which has a more intuitive name.
Sets the stopping error threshold for this training. The training stops if the training error reaches this value.
- Parameters:
maxError - maximum error threshold
- Returns:
this trainer
-
setStopError
public BackpropagationTrainer setStopError(float stopError)
The training stops if the training error reaches this value.
- Parameters:
stopError - value of training error at which to stop the training
- Returns:
this trainer
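The interplay of the epoch limit and the error threshold described above can be sketched as a single predicate. This is an illustration only, not the trainer's code: the class StopConditionSketch and the method shouldStop are hypothetical names, and the stop accuracy and early stopping conditions are left out.

```java
// Hypothetical illustration of the two basic stopping conditions described on this page.
final class StopConditionSketch {

    /**
     * Returns true when training should stop: either the epoch limit has been reached
     * or the training loss has dropped to the error threshold.
     */
    static boolean shouldStop(long currentEpoch, float trainingLoss, long stopEpochs, float stopError) {
        return currentEpoch >= stopEpochs || trainingLoss <= stopError;
    }

    public static void main(String[] args) {
        System.out.println(shouldStop(100, 0.5f, 100, 0.01f));  // epoch limit reached
        System.out.println(shouldStop(10, 0.005f, 100, 0.01f)); // error threshold reached
        System.out.println(shouldStop(10, 0.5f, 100, 0.01f));   // neither condition holds
    }
}
```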
-
getStopAccuracy
public float getStopAccuracy()
-
setStopAccuracy
public BackpropagationTrainer setStopAccuracy(float stopAccuracy)
-
setLearningRate
public BackpropagationTrainer setLearningRate(float learningRate)
The learning rate controls the step size, as a fraction of the error, used for adjusting the internal parameters (weights) of the neural network. With values that are too large, training may fail to converge and the error will grow, while with values that are too small, training might take too long or get stuck in a local minimum. A commonly used default value for this setting is 0.01, which in practice means that 1% of the error is used for weight modification.
- Parameters:
learningRate - a value in the range (0, 1), where 0.01 is used as a default initial value
- Returns:
this trainer
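The effect of the learning rate can be seen on a toy problem. The sketch below applies plain gradient descent to f(w) = w^2 (gradient 2w); it is not taken from the library, and the class LearningRateSketch with its methods step and descend is hypothetical. A small rate shrinks the weight toward the minimum at 0, while a rate that is too large makes the weight grow in magnitude with every step.

```java
// Hypothetical illustration of how the learning rate scales each weight update.
final class LearningRateSketch {

    /** One plain gradient descent step: w <- w - learningRate * gradient. */
    static float step(float weight, float gradient, float learningRate) {
        return weight - learningRate * gradient;
    }

    /** Minimizes f(w) = w^2 (gradient 2w) for the given number of steps and returns the final weight. */
    static float descend(float startWeight, float learningRate, int steps) {
        float w = startWeight;
        for (int i = 0; i < steps; i++) {
            w = step(w, 2f * w, learningRate);
        }
        return w;
    }

    public static void main(String[] args) {
        System.out.println("lr=0.01 -> " + descend(1f, 0.01f, 3)); // moves slowly toward 0
        System.out.println("lr=1.5  -> " + descend(1f, 1.5f, 3));  // magnitude grows: divergence
    }
}
```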
-
getLearningRate
public float getLearningRate()
The learning rate controls the step size, as a fraction of the error, used for adjusting the internal parameters (weights) of the neural network. With values that are too large, training may fail to converge and the error will grow, while with values that are too small, training might take too long or get stuck in a local minimum. A commonly used default value for this setting is 0.01, which in practice means that 1% of the error is used for weight modification.
-
getNeuralNetwork
public NeuralNetwork getNeuralNetwork()
Returns a neural network trained by this trainer.
- Returns:
instance of a neural network trained by this trainer
-
updateLearningRate
public void updateLearningRate(float learningRate)
Updates the learning rate for all layers during the learning rate decay. Used by the LearningRateDecay technique.
- Parameters:
learningRate - a value of learning rate to set for all layers
-
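setLearningRateDecay
public BackpropagationTrainer setLearningRateDecay(float decayRate)
Learning rate decay lowers the learning rate with each epoch by the decayRate factor, which may help lower the error.
- Parameters:
decayRate - factor by which the learning rate is lowered with each epoch
- Returns:
this trainer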
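This page does not state the exact decay formula. As one possible reading, the sketch below assumes a multiplicative schedule in which the learning rate is multiplied by decayRate after every epoch; the library may use a different formula. The class LearningRateDecaySketch and the method schedule are hypothetical.

```java
// Hypothetical illustration; assumes multiplicative decay, which the page does not confirm.
final class LearningRateDecaySketch {

    /** Returns the learning rate used in each epoch, starting from initialRate. */
    static float[] schedule(float initialRate, float decayRate, int epochs) {
        float[] rates = new float[epochs];
        float rate = initialRate;
        for (int epoch = 0; epoch < epochs; epoch++) {
            rates[epoch] = rate;
            rate *= decayRate; // assumption: rate shrinks by a constant factor per epoch
        }
        return rates;
    }

    public static void main(String[] args) {
        for (float rate : schedule(0.1f, 0.5f, 4)) {
            System.out.println(rate);
        }
    }
}
```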
-
setL2Regularization
public BackpropagationTrainer setL2Regularization(float regL2)
L2 regularization (sum of squares) is used to prevent overfitting and overly large weights.
- Parameters:
regL2 - coefficient for L2 regularization
- Returns:
this trainer
-
setL1Regularization
public BackpropagationTrainer setL1Regularization(float regL1)
L1 regularization (sum of absolute values) is used to prevent overfitting and overly large weights.
- Parameters:
regL1 - coefficient for L1 regularization
- Returns:
this trainer
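The two penalties named above differ only in how each weight contributes. The following sketch computes both for a small weight vector; it is an illustration under the assumption that the penalty is the coefficient times the plain sum, with no extra scaling such as 1/2, which this page does not specify. The class RegularizationSketch and its methods are hypothetical.

```java
// Hypothetical illustration of L1 and L2 penalty terms; scaling conventions are an assumption.
final class RegularizationSketch {

    /** L1 penalty: coefficient times the sum of absolute weight values. */
    static float l1Penalty(float[] weights, float regL1) {
        float sum = 0f;
        for (float w : weights) {
            sum += Math.abs(w);
        }
        return regL1 * sum;
    }

    /** L2 penalty: coefficient times the sum of squared weight values. */
    static float l2Penalty(float[] weights, float regL2) {
        float sum = 0f;
        for (float w : weights) {
            sum += w * w;
        }
        return regL2 * sum;
    }

    public static void main(String[] args) {
        float[] weights = {1f, -2f, 3f};
        System.out.println("L1 penalty: " + l1Penalty(weights, 0.1f));
        System.out.println("L2 penalty: " + l2Penalty(weights, 0.1f));
    }
}
```

Because the L2 term squares each weight, a single large weight is penalized much more heavily than under L1.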
-
getShuffle
public boolean getShuffle()
Returns the shuffle flag, which determines if the training set should be shuffled before each epoch.
- Returns:
value of the shuffle flag
-
setShuffle
public BackpropagationTrainer setShuffle(boolean shuffle)
Sets the shuffle flag, which determines if the training set should be shuffled before each epoch.
- Parameters:
shuffle - true to shuffle the training set before each epoch
- Returns:
this trainer
-
addListener
public void addListener(TrainingListener listener)
Adds a training listener to this trainer.
- Parameters:
listener - object that listens for the events in this trainer
-
removeListener
public void removeListener(TrainingListener listener)
Removes a training listener from this trainer.
- Parameters:
listener - listener to remove
-
isBatchMode
public boolean isBatchMode()
In batch mode weights are adjusted after a pass over all examples from the training set, while in online mode weights are adjusted after each training example.
-
setBatchMode
public BackpropagationTrainer setBatchMode(boolean batchMode)
Sets the flag that determines whether to use batch mode during the training. In batch mode weights are adjusted after a pass over all examples from the training set, while in online mode weights are adjusted after each training example.
- Parameters:
batchMode - true to use batch mode, false to use online mode
- Returns:
this trainer
-
getBatchSize
public int getBatchSize()
Batch size is the number of training examples after which the network's weights are adjusted.
-
setBatchSize
public BackpropagationTrainer setBatchSize(int batchSize)
Batch size is the number of training examples after which the network's weights are adjusted.
- Parameters:
batchSize - number of training examples after which the weights are adjusted
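The difference between online mode, batch mode, and an intermediate batch size comes down to how many examples are accumulated before each weight adjustment. The sketch below counts the updates made in one epoch for a single weight. It is not the library's code: the class BatchUpdateSketch and the method runEpoch are hypothetical, and averaging the gradient over the batch is an assumption.

```java
// Hypothetical illustration of how batch size determines the number of weight updates per epoch.
final class BatchUpdateSketch {

    /**
     * Runs one epoch over per-example gradients of a single weight and returns
     * the number of weight updates performed. weight[0] is updated in place.
     */
    static int runEpoch(float[] weight, float[] gradients, int batchSize, float learningRate) {
        if (batchSize < 1) {
            throw new IllegalArgumentException("batchSize must be at least 1");
        }
        int updates = 0;
        int inBatch = 0;
        float gradientSum = 0f;
        for (float g : gradients) {
            gradientSum += g;
            inBatch++;
            if (inBatch == batchSize) {
                weight[0] -= learningRate * gradientSum / inBatch; // assumption: average over the batch
                updates++;
                inBatch = 0;
                gradientSum = 0f;
            }
        }
        if (inBatch > 0) { // leftover examples form a final, smaller batch
            weight[0] -= learningRate * gradientSum / inBatch;
            updates++;
        }
        return updates;
    }

    public static void main(String[] args) {
        float[] gradients = new float[10];
        java.util.Arrays.fill(gradients, 1f);
        System.out.println("online (batchSize=1): " + runEpoch(new float[] {0f}, gradients, 1, 0.1f));
        System.out.println("mini-batch (batchSize=4): " + runEpoch(new float[] {0f}, gradients, 4, 0.1f));
        System.out.println("full batch (batchSize=10): " + runEpoch(new float[] {0f}, gradients, 10, 0.1f));
    }
}
```

In this sketch a batch size of 1 corresponds to online mode, and a batch size equal to the training set size corresponds to batch mode.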
-
setMomentum
public BackpropagationTrainer setMomentum(float momentum)
The momentum setting helps to avoid oscillations in weight changes and gives more stable and faster training. It has an effect only if the momentum optimizer is used.
- Parameters:
momentum - a decimal value greater than zero and less than one
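To show why momentum smooths weight changes, here is a minimal sketch of the classical momentum update, in which a velocity term carries over a fraction of the previous change. The library's momentum optimizer may be formulated differently; the class MomentumSketch and the method step are hypothetical.

```java
// Hypothetical illustration of the classical momentum update rule.
final class MomentumSketch {

    /** Returns {newWeight, newVelocity} after one momentum step. */
    static float[] step(float weight, float velocity, float gradient, float learningRate, float momentum) {
        // The new velocity blends the previous velocity with the current gradient step.
        float newVelocity = momentum * velocity - learningRate * gradient;
        return new float[] {weight + newVelocity, newVelocity};
    }

    public static void main(String[] args) {
        float[] state = {0f, 0f}; // weight and velocity both start at zero
        for (int i = 0; i < 3; i++) {
            state = step(state[0], state[1], 1f, 0.1f, 0.9f);
            System.out.println("weight=" + state[0] + " velocity=" + state[1]);
        }
    }
}
```

With a constant gradient, each step is larger than the previous one because the velocity accumulates, which is what speeds up training along a consistent direction.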
-
getMomentum
public float getMomentum()
The momentum setting helps to avoid oscillations in weight changes and gives more stable and faster training. It has an effect only if the momentum optimizer is used.
-
stop
public void stop()
Stops the training.
-
getTrainingLoss
public float getTrainingLoss()
Total training error/loss at the current epoch. The error is calculated using the loss function and is also referred to as the loss.
- Returns:
total training error/loss at the current epoch
-
getValidationLoss
public float getValidationLoss()
Validation loss is an error calculated using the validation set, used to prevent overfitting and to validate the architecture and training settings.
- Returns:
error/loss calculated using the validation set
-
getTrainingAccuracy
public float getTrainingAccuracy()
Accuracy metric which gives the percentage of correct predictions for the training set.
- Returns:
classification accuracy for the training examples
-
getValidationAccuracy
public float getValidationAccuracy()
Accuracy metric which gives the percentage of correct predictions for the validation set.
- Returns:
classification accuracy for examples in the validation set
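Classification accuracy as described above is the share of examples whose predicted class matches the actual class. The sketch below computes it as a fraction between 0 and 1; whether the trainer reports a fraction or a value scaled to 100 is not stated on this page. The class AccuracySketch and the method accuracy are hypothetical.

```java
// Hypothetical illustration of classification accuracy as correct / total.
final class AccuracySketch {

    /** Returns the fraction of positions where the predicted class equals the actual class. */
    static float accuracy(int[] predicted, int[] actual) {
        if (predicted.length != actual.length || predicted.length == 0) {
            throw new IllegalArgumentException("arrays must be non-empty and of equal length");
        }
        int correct = 0;
        for (int i = 0; i < predicted.length; i++) {
            if (predicted[i] == actual[i]) {
                correct++;
            }
        }
        return (float) correct / predicted.length;
    }

    public static void main(String[] args) {
        int[] predicted = {0, 1, 1, 2};
        int[] actual = {0, 1, 2, 2};
        System.out.println("accuracy=" + accuracy(predicted, actual));
    }
}
```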
-
getCurrentEpoch
public int getCurrentEpoch()
Returns the current training epoch (iteration) of this trainer. An epoch is one pass over all examples from a training set.
- Returns:
current training epoch
-
getOptimizer
public OptimizerType getOptimizer()
-
setOptimizer
public BackpropagationTrainer setOptimizer(OptimizerType optimizer)
-
getTestSet
public javax.visrec.ml.data.DataSet<?> getTestSet()
The test set is used after the training to estimate the performance of the trained model and its ability to generalize to new data. Examples (data) from the test set should never be used during the training. A test set is commonly generated by splitting all available data into training and test sets in some ratio.
- Returns:
test set - example data not used during the training, that will be used for evaluation/testing of the trained model
-
setTestSet
public void setTestSet(javax.visrec.ml.data.DataSet<MLDataItem> testSet)
The test set is used after the training to estimate the performance of the trained model and its ability to generalize to new data. Examples (data) from the test set should never be used during the training. A test set is commonly generated by splitting all available data into training and test sets in some ratio.
- Parameters:
testSet - example data not used during the training, that will be used for evaluation/testing of the trained model
-
getEarlyStopping
public boolean getEarlyStopping()
Early stopping stops the training if it starts converging slowly, which helps prevent overfitting.
-
setEarlyStopping
public BackpropagationTrainer setEarlyStopping(boolean earlyStopping)
Early stopping stops the training if it starts converging slowly, which helps prevent overfitting.
- Parameters:
earlyStopping - true to enable early stopping
- Returns:
this trainer
-
setSnapshotPath
public BackpropagationTrainer setSnapshotPath(String snapshotPath)
Path to use for making snapshots, that is, saving the current state of the trained network during the training in order to be able to restore it from a training point.
- Parameters:
snapshotPath - directory to store snapshots of the neural network during the training
- Returns:
this trainer
-
getSnapshotPath
public String getSnapshotPath()
Path to use for making snapshots, that is, saving the current state of the trained network during the training in order to be able to restore it from a training point if needed.
- Returns:
directory to store snapshots of the neural network during the training
-
getSnapshotEpochs
public int getSnapshotEpochs()
How often, in epochs, to make training snapshots.
-
setSnapshotEpochs
public void setSnapshotEpochs(int snapshotEpochs)
How often, in epochs, to make training snapshots.
- Parameters:
snapshotEpochs - number of epochs between training snapshots
-
setTrainingSnapshots
public void setTrainingSnapshots(boolean trainingSnapshots)
Training snapshots save the current state of the trained neural network during the training in order to be able to restore it from a training point if needed.
- Parameters:
trainingSnapshots - true to create training snapshots
-
createsTrainingSnaphots
public boolean createsTrainingSnaphots()
Returns true if the network creates training snapshots, false otherwise. Training snapshots save the current state of the trained neural network during the training in order to be able to restore it from a training point if needed.
-
getEarlyStoppingMinLossChange
public float getEarlyStoppingMinLossChange()
Early stopping stops the training if the error/loss starts converging too slowly. If the loss change is lower than the given value for patience epochs, the training will stop.
-
setEarlyStoppingMinLossChange
public BackpropagationTrainer setEarlyStoppingMinLossChange(float earlyStoppingMinLossChange)
Early stopping stops the training if the error/loss starts converging too slowly. If the loss change is lower than the given value for patience epochs, the training will stop.
- Parameters:
earlyStoppingMinLossChange - minimum loss change below which convergence is considered too slow
- Returns:
this trainer
-
getEarlyStoppingPatience
public int getEarlyStoppingPatience()
How many epochs to wait to see if the loss is decreasing too slowly.
-
setEarlyStoppingPatience
public BackpropagationTrainer setEarlyStoppingPatience(int earlyStoppingPatience)
How many epochs to wait to see if the loss is decreasing too slowly.
- Parameters:
earlyStoppingPatience - number of epochs to wait before stopping
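The interaction of the minimum loss change and the patience setting can be sketched as follows. This is one plausible reading of the rule described above, not the library's implementation: it assumes the loss change is measured between consecutive epochs and that the count of slow epochs resets whenever a sufficiently large improvement occurs. The class EarlyStoppingSketch and the method stopEpoch are hypothetical.

```java
// Hypothetical illustration of early stopping with a minimum loss change and patience.
final class EarlyStoppingSketch {

    /**
     * Returns the 1-based epoch at which training would stop,
     * or -1 if the stopping rule never triggers for the given losses.
     */
    static int stopEpoch(float[] losses, float minLossChange, int patience) {
        int slowEpochs = 0;
        for (int e = 1; e < losses.length; e++) {
            float improvement = losses[e - 1] - losses[e];
            if (improvement < minLossChange) {
                slowEpochs++;   // loss is decreasing too slowly (or increasing)
            } else {
                slowEpochs = 0; // assumption: a good epoch resets the counter
            }
            if (slowEpochs >= patience) {
                return e + 1;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        float[] losses = {1.0f, 0.5f, 0.49f, 0.489f, 0.4889f};
        System.out.println("stop at epoch " + stopEpoch(losses, 0.05f, 2));
    }
}
```

For the losses in main, the improvement falls below 0.05 in the third and fourth epochs, so with a patience of 2 the sketch stops at epoch 4.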
-
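getCheckpointEpochs
public int getCheckpointEpochs()
How often, in epochs, snapshots of the trained network should be created.
-
setCheckpointEpochs
public BackpropagationTrainer setCheckpointEpochs(int checkpointEpochs)
How often, in epochs, snapshots of the trained network should be created.
- Parameters:
checkpointEpochs - number of epochs between snapshots
-
setProperties
public final void setProperties(Properties prop)
Sets properties from the available keys in the specified prop object.
- Parameters:
prop - key, value pairs of properties for backpropagation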
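As an illustration of configuring settings from key, value pairs, the sketch below reads two values from a java.util.Properties object with fallbacks. The key strings "learningRate" and "maxEpochs" are assumptions based on the field descriptions; the actual values of the PROP_* constants are not shown on this page. The class TrainerPropertiesSketch and its methods are hypothetical.

```java
import java.util.Properties;

// Hypothetical illustration of reading training settings from a Properties object.
final class TrainerPropertiesSketch {

    // Assumed key names; the real values of PROP_LEARNING_RATE and PROP_MAX_EPOCHS may differ.
    static final String KEY_LEARNING_RATE = "learningRate";
    static final String KEY_MAX_EPOCHS = "maxEpochs";

    /** Returns the learning rate from prop, or the fallback if the key is absent. */
    static float learningRate(Properties prop, float fallback) {
        String value = prop.getProperty(KEY_LEARNING_RATE);
        return value == null ? fallback : Float.parseFloat(value);
    }

    /** Returns the epoch limit from prop, or the fallback if the key is absent. */
    static long maxEpochs(Properties prop, long fallback) {
        String value = prop.getProperty(KEY_MAX_EPOCHS);
        return value == null ? fallback : Long.parseLong(value);
    }

    public static void main(String[] args) {
        Properties prop = new Properties();
        prop.setProperty(KEY_LEARNING_RATE, "0.05");
        System.out.println("learningRate=" + learningRate(prop, 0.01f));
        System.out.println("maxEpochs=" + maxEpochs(prop, 1000L)); // key absent, fallback used
    }
}
```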
-
setDropout
public BackpropagationTrainer setDropout(float dropout)
Dropout is a technique to prevent overfitting, which skips adjusting weights for some neurons with a given probability.
- Parameters:
dropout - value between 0.2 and 0.8 which represents the probability of skipping the weight adjustment
- Returns:
this trainer
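Following the description above, dropout can be pictured as a random mask that decides, per neuron, whether its weights are adjusted in the current step. The sketch below draws such a mask; it is not the library's implementation, and the class DropoutSketch and the method updateMask are hypothetical.

```java
import java.util.Random;

// Hypothetical illustration of a per-neuron dropout mask.
final class DropoutSketch {

    /**
     * Returns a mask where true means the neuron's weights are adjusted in this step
     * and false means the adjustment is skipped, which happens with probability dropout.
     */
    static boolean[] updateMask(int neurons, float dropout, Random random) {
        boolean[] mask = new boolean[neurons];
        for (int i = 0; i < neurons; i++) {
            // nextFloat() is uniform in [0, 1), so it is below dropout with probability dropout.
            mask[i] = random.nextFloat() >= dropout;
        }
        return mask;
    }

    public static void main(String[] args) {
        boolean[] mask = updateMask(10, 0.5f, new Random());
        int updated = 0;
        for (boolean m : mask) {
            if (m) {
                updated++;
            }
        }
        System.out.println(updated + " of " + mask.length + " neurons updated in this step");
    }
}
```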
-
getDropout
public float getDropout()
Dropout is a technique to prevent overfitting, which skips adjusting weights for some neurons with a given probability.
- Returns:
value between 0.2 and 0.8 which represents the probability of skipping the weight adjustment
-
getExtendedLogging
public boolean getExtendedLogging()
Extended logging includes additional info for debugging the training.
-
setExtendedLogging
public void setExtendedLogging(boolean extendedLogging)
Extended logging includes additional info for debugging the training.
- Parameters:
extendedLogging - true to enable extended logging
-