Package deepnetts.net.loss
Class MeanSquaredErrorLoss
java.lang.Object
deepnetts.net.loss.MeanSquaredErrorLoss
- All Implemented Interfaces:
LossFunction, Serializable
Mean squared error (MSE) loss function. Sums the squared errors over all input patterns and all outputs. It should be used for regression problems.
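Math formula:
$$E = \frac{1}{2NK} \sum_{n=1}^{N} \sum_{k=1}^{K} \left(y_{nk} - t_{nk}\right)^2 + \mathrm{regSum}$$
where N is the number of patterns, K is the dimension of the output vector, y_{nk} and t_{nk} are the predicted and target values of output k for pattern n, and regSum is the L1 or L2 regularization term multiplied by lambda.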
Bishop, p. 89, eq. 3.34.
This formula is also recommended in the Proben1 technical report.
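The formula above can be evaluated directly over a batch of patterns. The following is a minimal sketch, not the DeepNetts implementation: the class name `MseFormula`, the method `loss`, and the array layout (`predicted[n][k]` is output k of pattern n) are illustrative assumptions, and `regSum` is assumed to be passed in already multiplied by lambda.

```java
/** Hypothetical helper (not part of DeepNetts) evaluating E = 1/(2*N*K) * sum_n sum_k (y - t)^2 + regSum. */
final class MseFormula {

    private MseFormula() {}

    /**
     * @param predicted network outputs y, indexed as [pattern][output]
     * @param target    target outputs t, same shape as predicted
     * @param regSum    regularization term, assumed already multiplied by lambda
     */
    static float loss(float[][] predicted, float[][] target, float regSum) {
        int n = predicted.length;                 // N: number of patterns
        if (n == 0 || n != target.length) {
            throw new IllegalArgumentException("need the same, non-zero number of patterns");
        }
        int k = predicted[0].length;              // K: dimension of the output vector
        if (k == 0) {
            throw new IllegalArgumentException("output vectors must not be empty");
        }
        double sum = 0.0;                         // accumulate in double to limit rounding error
        for (int p = 0; p < n; p++) {
            if (predicted[p].length != k || target[p].length != k) {
                throw new IllegalArgumentException("all output vectors must have length K");
            }
            for (int i = 0; i < k; i++) {
                double diff = predicted[p][i] - target[p][i];
                sum += diff * diff;
            }
        }
        return (float) (sum / (2.0 * n * k) + regSum);
    }

    public static void main(String[] args) {
        float[][] y = {{1f, 2f}, {3f, 4f}};
        float[][] t = {{0f, 2f}, {3f, 2f}};
        // squared errors: 1 + 0 + 0 + 4 = 5, divided by 2 * N * K = 8
        System.out.println(loss(y, t, 0f));
    }
}
```

The 1/2 factor is a convention that cancels the 2 produced when differentiating the square, so the gradient with respect to each output is simply proportional to y - t.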
Constructor Summary
- MeanSquaredErrorLoss(int layerWidth)
- MeanSquaredErrorLoss(NeuralNetwork neuralNet): Creates a new mean squared error loss for the given neural network.
Method Summary
- float[] addPatternError(float[] predictedOutput, float[] targetOutput): Adds the output error vector for the given predicted and target output vectors to the total error sum and returns an error vector.
- addPatternError(TensorBase predictedOut, TensorBase targetOut)
- void addRegularizationSum(float regSum): Adds the regularization sum to the total loss.
- float getPatternLoss()
- float getTotal(): Returns the total error calculated by this loss function.
- void reset(): Resets the total error and pattern counter.
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface deepnetts.net.loss.LossFunction
valueFor
-
Constructor Details
-
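MeanSquaredErrorLoss
public MeanSquaredErrorLoss(NeuralNetwork neuralNet)
Creates a new mean squared error loss for the given neural network.
- Parameters:
neuralNet - the neural network this loss function is created for
-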
MeanSquaredErrorLoss
public MeanSquaredErrorLoss(int layerWidth)
-
Method Details
-
addPatternError
public float[] addPatternError(float[] predictedOutput, float[] targetOutput)
Adds the output error vector for the given predicted and target output vectors to the total error sum and returns an error vector.
- Specified by:
addPatternError in interface LossFunction
- Parameters:
predictedOutput - the output vector predicted by the network
targetOutput - the target output vector
- Returns:
the error vector
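For a single pattern, the work described for addPatternError amounts to computing an error vector and adding its squared components to a running sum. A minimal sketch follows. It assumes the returned error vector is the element-wise difference predicted - target, which is the gradient of ½Σ(y - t)² with respect to y; that sign convention, the class name `PatternErrorSketch`, and its accessor methods are assumptions for illustration, not taken from the DeepNetts source.

```java
/** Hypothetical per-pattern step (not the DeepNetts class): error vector plus running squared-error sum. */
final class PatternErrorSketch {

    private double squaredErrorSum = 0.0; // sum of (y - t)^2 over all patterns added so far
    private int patternCount = 0;          // number of patterns added so far

    /** Returns y - t for one pattern and adds its squared components to the running sum. */
    float[] addPatternError(float[] predictedOutput, float[] targetOutput) {
        if (predictedOutput.length != targetOutput.length) {
            throw new IllegalArgumentException("output vectors must have the same length");
        }
        float[] error = new float[predictedOutput.length];
        for (int i = 0; i < error.length; i++) {
            error[i] = predictedOutput[i] - targetOutput[i]; // assumed sign convention: y - t
            squaredErrorSum += (double) error[i] * error[i];
        }
        patternCount++;
        return error;
    }

    double squaredErrorSum() { return squaredErrorSum; }

    int patternCount() { return patternCount; }

    public static void main(String[] args) {
        PatternErrorSketch acc = new PatternErrorSketch();
        float[] err = acc.addPatternError(new float[] {0.5f, 1.0f}, new float[] {0.0f, 1.5f});
        System.out.println(java.util.Arrays.toString(err)); // error vector y - t
        System.out.println(acc.squaredErrorSum());          // 0.25 + 0.25
    }
}
```

Returning the error vector alongside the accumulation lets a caller reuse it as the output-layer gradient for backpropagation without a second pass over the outputs.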
-
addPatternError
addPatternError(TensorBase predictedOut, TensorBase targetOut)
- Specified by:
addPatternError in interface LossFunction
-
addRegularizationSum
public void addRegularizationSum(float regSum)
Adds the regularization sum to the total loss.
- Specified by:
addRegularizationSum in interface LossFunction
- Parameters:
regSum - the regularization sum
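The class description states that regSum is an L1 or L2 regularization term multiplied by lambda. A minimal sketch of computing such a value over a flat array of weights follows. The exact form used by DeepNetts (for example, whether the L2 term carries an extra factor of ½) is not stated on this page, so `lambda * Σ|w|` and `lambda * Σw²` are assumptions, as are the class name `RegularizationSketch` and its method names.

```java
/** Hypothetical helpers (not part of DeepNetts) producing a value to pass to addRegularizationSum. */
final class RegularizationSketch {

    private RegularizationSketch() {}

    /** L1 term: lambda * sum of absolute weights (assumed form). */
    static float l1(float[] weights, float lambda) {
        double sum = 0.0;
        for (float w : weights) {
            sum += Math.abs(w);
        }
        return (float) (lambda * sum);
    }

    /** L2 term: lambda * sum of squared weights (assumed form, no 1/2 factor). */
    static float l2(float[] weights, float lambda) {
        double sum = 0.0;
        for (float w : weights) {
            sum += (double) w * w;
        }
        return (float) (lambda * sum);
    }

    public static void main(String[] args) {
        float[] weights = {0.5f, -1.0f, 2.0f};
        System.out.println(l1(weights, 0.1f)); // lambda * (0.5 + 1 + 2)
        System.out.println(l2(weights, 0.1f)); // lambda * (0.25 + 1 + 4)
    }
}
```

L1 penalizes every nonzero weight equally and tends to drive small weights to exactly zero, while L2 penalizes large weights more strongly and shrinks them smoothly.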
-
getTotal
public float getTotal()
Description copied from interface: LossFunction
Returns the total error calculated by this loss function.
- Specified by:
getTotal in interface LossFunction
- Returns:
the total error calculated by this loss function
-
reset
public void reset()
Description copied from interface: LossFunction
Resets the total error and pattern counter.
- Specified by:
reset in interface LossFunction
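Taken together, addPatternError, addRegularizationSum, getTotal and reset suggest an accumulate-then-read cycle: pattern errors are added one at a time, the regularization sum is added, getTotal returns E from the class formula, and reset clears the totals before the next pass. The sketch below mirrors that cycle in a standalone class. It does not implement the real LossFunction interface, and the class name `MseAccumulatorSketch`, its internal fields, computing E lazily in getTotal, and clearing the regularization sum in reset are all assumptions for illustration.

```java
/** Hypothetical accumulator (not the DeepNetts class) mirroring the addPatternError/getTotal/reset cycle. */
final class MseAccumulatorSketch {

    private double squaredErrorSum = 0.0; // sum of (y - t)^2 over added patterns
    private float regSum = 0.0f;          // regularization term, assumed already multiplied by lambda
    private int patternCount = 0;         // N
    private int outputWidth = -1;         // K, fixed by the first pattern added

    float[] addPatternError(float[] predicted, float[] target) {
        if (predicted.length == 0 || predicted.length != target.length
                || (outputWidth >= 0 && predicted.length != outputWidth)) {
            throw new IllegalArgumentException("inconsistent output vector length");
        }
        outputWidth = predicted.length;
        float[] error = new float[predicted.length];
        for (int i = 0; i < error.length; i++) {
            error[i] = predicted[i] - target[i];            // assumed sign convention: y - t
            squaredErrorSum += (double) error[i] * error[i];
        }
        patternCount++;
        return error;
    }

    void addRegularizationSum(float regSum) {
        this.regSum += regSum;                              // added to the total, per the method description
    }

    /** E = 1/(2*N*K) * sum (y - t)^2 + regSum; returns only regSum if no pattern was added. */
    float getTotal() {
        if (patternCount == 0) {
            return regSum;
        }
        return (float) (squaredErrorSum / (2.0 * patternCount * outputWidth) + regSum);
    }

    /** Clears the accumulated error, the pattern counter and (by assumption) the regularization sum. */
    void reset() {
        squaredErrorSum = 0.0;
        regSum = 0.0f;
        patternCount = 0;
        outputWidth = -1;
    }

    public static void main(String[] args) {
        MseAccumulatorSketch loss = new MseAccumulatorSketch();
        loss.addPatternError(new float[] {1f, 2f}, new float[] {0f, 2f});
        loss.addPatternError(new float[] {3f, 4f}, new float[] {3f, 2f});
        System.out.println(loss.getTotal()); // (1 + 0 + 0 + 4) / (2 * 2 * 2)
        loss.reset();
        System.out.println(loss.getTotal()); // totals cleared
    }
}
```

In a typical training loop, reset would be called at the start of each epoch and getTotal read at its end, so the reported value is the mean error over that epoch only.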
-
getPatternLoss
public float getPatternLoss()
- Specified by:
getPatternLoss in interface LossFunction
-