Package deepnetts.net.loss
Class BinaryCrossEntropyLoss
java.lang.Object
deepnetts.net.loss.BinaryCrossEntropyLoss
- All Implemented Interfaces:
LossFunction, Serializable
Binary cross-entropy loss is a loss function used for binary classification tasks (two classes, with a single output that represents a probability).
It should be used in combination with a sigmoid activation function in the output layer.
The formula:
E = -(1/n) * SUM( t * ln(y) + (1 - t) * ln(1 - y) )
where t is the target output, y is the actual (predicted) output, and the sum runs over the n patterns.
Bishop, C. pg. 231, eq. 6.120
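As a quick check of the formula, here is a minimal standalone sketch in Java. It is not the deepnetts implementation: the class name BceFormulaDemo, the use of double instead of float, and the EPS clamp that keeps ln(y) and ln(1 - y) finite are assumptions made for illustration. Each array entry stands for one pattern's single sigmoid output and its target.

```java
final class BceFormulaDemo {

    // Assumption: a small clamp so that ln(0) is never evaluated.
    private static final double EPS = 1e-7;

    /** Mean binary cross-entropy: E = -(1/n) * SUM( t*ln(y) + (1-t)*ln(1-y) ). */
    static double binaryCrossEntropy(double[] y, double[] t) {
        if (y.length != t.length || y.length == 0) {
            throw new IllegalArgumentException("y and t must be non-empty and of equal length");
        }
        double sum = 0.0;
        for (int i = 0; i < y.length; i++) {
            // Clamp the prediction into (0, 1) so both logarithms stay finite.
            double yi = Math.min(Math.max(y[i], EPS), 1.0 - EPS);
            sum += t[i] * Math.log(yi) + (1.0 - t[i]) * Math.log(1.0 - yi);
        }
        // Negate and average over the n patterns.
        return -sum / y.length;
    }

    public static void main(String[] args) {
        double[] y = {0.9, 0.2, 0.6};  // sigmoid outputs, one per pattern
        double[] t = {1.0, 0.0, 1.0};  // binary targets
        System.out.printf("E = %.4f%n", binaryCrossEntropy(y, t));
    }
}
```

For the sample values in main, the result is about 0.2798, the mean of -ln 0.9, -ln 0.8 and -ln 0.6.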
Constructor Summary
Method Summary
Modifier and Type | Method | Description
float[] | addPatternError(float[] pred, float[] target) | Calculates error for given actual and target patterns and adds that error to total error.
 | addPatternError(TensorBase predictedOutput, TensorBase targetOutput) |
void | addRegularizationSum(float regSum) | Adds specified regularization sum to total loss.
float | getPatternLoss() |
float | getTotal() | Returns the total error calculated by this loss function.
void | reset() | Resets the total error and pattern counter.

Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface deepnetts.net.loss.LossFunction
valueFor
-
Constructor Details
-
Method Details
-
addPatternError
public float[] addPatternError(float[] pred, float[] target)
Calculates the error for the given predicted and target patterns and adds that error to the total error. Returns the output error vector for the specified predicted and target outputs.
- Specified by:
addPatternError in interface LossFunction
- Parameters:
pred - predicted output of a neural network
target - target output of a neural network
- Returns:
error vector for the specified predicted and target outputs
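Pairing this loss with a sigmoid output is convenient because, with y = sigmoid(z), the derivative of the per-pattern loss with respect to the pre-activation z simplifies to y - t. The source does not state which error vector addPatternError returns, so treat y - t here as the textbook result rather than a description of this class. The sketch below (SigmoidBceGradientCheck and its methods are hypothetical names) compares that analytic gradient with a central finite-difference estimate.

```java
final class SigmoidBceGradientCheck {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Per-pattern binary cross-entropy for one sigmoid output y and target t. */
    static double loss(double y, double t) {
        return -(t * Math.log(y) + (1.0 - t) * Math.log(1.0 - y));
    }

    /** Analytic dE/dz for a sigmoid output: y - t. */
    static double analyticGradient(double z, double t) {
        return sigmoid(z) - t;
    }

    /** Central finite-difference estimate of dE/dz, used only to check the analytic form. */
    static double numericGradient(double z, double t) {
        double h = 1e-5;
        return (loss(sigmoid(z + h), t) - loss(sigmoid(z - h), t)) / (2.0 * h);
    }

    public static void main(String[] args) {
        double[] zs = {-2.0, 0.3, 1.5};
        double[] ts = {0.0, 1.0, 1.0};
        for (int i = 0; i < zs.length; i++) {
            System.out.printf("z=%.1f t=%.0f analytic=%.6f numeric=%.6f%n",
                    zs[i], ts[i], analyticGradient(zs[i], ts[i]), numericGradient(zs[i], ts[i]));
        }
    }
}
```

The identity follows from the chain rule: dE/dy = -t/y + (1 - t)/(1 - y) and dy/dz = y(1 - y), whose product is y - t.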
-
addPatternError
addPatternError(TensorBase predictedOutput, TensorBase targetOutput)
- Specified by:
addPatternError in interface LossFunction
-
getPatternLoss
public float getPatternLoss()
- Specified by:
getPatternLoss in interface LossFunction
-
addRegularizationSum
public void addRegularizationSum(float regSum)
Description copied from interface: LossFunction
Adds specified regularization sum to total loss.
- Specified by:
addRegularizationSum in interface LossFunction
- Parameters:
regSum - regularization sum
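The source does not say what the regularization sum contains or how it is combined with the data loss. One common case is an L2 (weight decay) penalty, (lambda / 2) * SUM(w^2) over all weights, added to the cross-entropy. The sketch below assumes exactly that; RegularizedLossSketch, l2Sum, lambda and the weight layout are hypothetical and not part of deepnetts.

```java
final class RegularizedLossSketch {

    /**
     * Hypothetical L2 penalty: (lambda / 2) * SUM(w^2) over all weights.
     * Assumption: deepnetts may build its regularization sum differently.
     */
    static float l2Sum(float[][] weightsPerLayer, float lambda) {
        float sumOfSquares = 0f;
        for (float[] layer : weightsPerLayer) {
            for (float w : layer) {
                sumOfSquares += w * w;
            }
        }
        return 0.5f * lambda * sumOfSquares;
    }

    public static void main(String[] args) {
        float[][] weights = {{0.5f, -1.0f}, {2.0f}};  // toy weights, two "layers"
        float dataLoss = 0.25f;                        // stand-in for a mean cross-entropy value
        float regSum = l2Sum(weights, 0.01f);          // 0.5 * 0.01 * (0.25 + 1 + 4), about 0.02625
        float totalLoss = dataLoss + regSum;           // data loss plus regularization
        System.out.println("regSum = " + regSum + ", total loss = " + totalLoss);
    }
}
```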
-
getTotal
public float getTotal()
Description copied from interface: LossFunction
Returns the total error calculated by this loss function.
- Specified by:
getTotal in interface LossFunction
- Returns:
total error calculated by this loss function
-
reset
public void reset()
Description copied from interface: LossFunction
Resets the total error and pattern counter.
- Specified by:
reset in interface LossFunction
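Taken together, addPatternError, addRegularizationSum, getTotal and reset describe an accumulator that is filled pattern by pattern and cleared between passes over the data. The following is a minimal sketch of that lifecycle under stated assumptions, not the deepnetts source: BceAccumulator is a hypothetical class, getTotal is assumed to return the mean per-pattern loss plus the accumulated regularization sum, and the returned error vector is assumed to be y - t.

```java
import java.util.Arrays;

final class BceAccumulator {

    private static final float EPS = 1e-7f;  // keeps ln() finite; the value is an assumption

    private float totalError;   // sum of per-pattern losses
    private float regSum;       // regularization added via addRegularizationSum
    private int patternCount;   // patterns seen since the last reset

    /** Adds one pattern's cross-entropy to the total; returns the (assumed) error vector y - t. */
    float[] addPatternError(float[] pred, float[] target) {
        float[] error = new float[pred.length];
        for (int i = 0; i < pred.length; i++) {
            float y = Math.min(Math.max(pred[i], EPS), 1f - EPS);
            double patternLoss = -(target[i] * Math.log(y) + (1 - target[i]) * Math.log(1 - y));
            totalError += (float) patternLoss;
            error[i] = pred[i] - target[i];
        }
        patternCount++;
        return error;
    }

    void addRegularizationSum(float regSum) {
        this.regSum += regSum;
    }

    /** Assumption: mean loss per pattern plus the regularization term. */
    float getTotal() {
        return patternCount == 0 ? 0f : totalError / patternCount + regSum;
    }

    /** Clears the accumulated error, regularization and pattern counter. */
    void reset() {
        totalError = 0f;
        regSum = 0f;
        patternCount = 0;
    }

    public static void main(String[] args) {
        BceAccumulator loss = new BceAccumulator();
        for (int epoch = 1; epoch <= 2; epoch++) {
            loss.reset();  // start each epoch from zero
            float[] err1 = loss.addPatternError(new float[]{0.9f}, new float[]{1f});
            float[] err2 = loss.addPatternError(new float[]{0.2f}, new float[]{0f});
            System.out.println("epoch " + epoch + ": errors " + Arrays.toString(err1)
                    + " " + Arrays.toString(err2) + ", total " + loss.getTotal());
        }
    }
}
```

Calling reset at the start of each epoch keeps the average from mixing losses of different epochs; with the two sample patterns, getTotal returns about 0.1643.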
-