public class Backpropagation extends Propagation implements Momentum, LearningRate

| Modifier and Type | Field and Description |
|---|---|
| static java.lang.String | LAST_DELTA: The resume key for backpropagation. |

Fields inherited from class Propagation: dropoutRandomSource, gradients, network

| Constructor and Description |
|---|
| Backpropagation(ContainsFlat network, MLDataSet training): Create a class to train using backpropagation. |
| Backpropagation(ContainsFlat network, MLDataSet training, double theLearnRate, double theMomentum) |

| Modifier and Type | Method and Description |
|---|---|
| boolean | canContinue() |
| double[] | getLastDelta() |
| double | getLearningRate() |
| double | getMomentum() |
| void | initOthers(): Perform training method specific init. |
| boolean | isNesterovUpdate() |
| boolean | isValidResume(TrainingContinuation state): Determine if the specified continuation object is valid to resume with. |
| TrainingContinuation | pause(): Pause the training. |
| void | resume(TrainingContinuation state): Resume training. |
| void | setLearningRate(double rate): Set the learning rate. This value is essentially a percentage. |
| void | setMomentum(double m): Set the momentum for training. |
| void | setNesterovUpdate(boolean nesterovUpdate) |
| double | updateWeight(double[] gradients, double[] lastGradient, int index): Update a weight. |
| double | updateWeight(double[] gradients, double[] lastGradient, int index, double dropoutRate): Update a weight. |

Methods inherited from class Propagation: calculateGradients, finishTraining, finishTraining, fixFlatSpot, getBatchSize, getCurrentFlatNetwork, getDropoutRate, getL1, getL2, getLastGradient, getMethod, getThreadCount, iteration, iteration, learn, learnLimited, report, rollIteration, setBatchSize, setDroupoutRate, setErrorFunction, setL1, setL2, setThreadCount

Methods inherited from class BasicTraining: addStrategy, getError, getImplementationType, getIteration, getStrategies, getTraining, isTrainingDone, postIteration, preIteration, setError, setIteration, setTraining

Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

Methods inherited from interface MLTrain: addStrategy, getError, getImplementationType, getIteration, getStrategies, getTraining, isTrainingDone, setError, setIteration

public static final java.lang.String LAST_DELTA

public Backpropagation(ContainsFlat network, MLDataSet training)
network - The network that is to be trained.
training - The training data to be used for backpropagation.

public Backpropagation(ContainsFlat network, MLDataSet training, double theLearnRate, double theMomentum)
network - The network that is to be trained.
training - The training set.
theLearnRate - The rate at which the weight matrix will be adjusted based on learning.
theMomentum - The influence that the previous iteration's training deltas will have on the current iteration.

public boolean canContinue()
canContinue in interface MLTrain

public double[] getLastDelta()

public double getLearningRate()
getLearningRate in interface LearningRate

public double getMomentum()
getMomentum in interface Momentum

public boolean isValidResume(TrainingContinuation state)
state - The continuation object to check.

public TrainingContinuation pause()

public void resume(TrainingContinuation state)
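
The pause(), isValidResume() and resume() methods above, together with the LAST_DELTA resume key, suggest a save-and-restore cycle for the trainer's previous weight deltas. The following is a minimal sketch of that idea, not Encog's source. It assumes a plain Map stands in for TrainingContinuation, and that a state is valid when it holds a delta array of the expected length. The class ResumeSketch, the constant LAST_DELTA_KEY and its string value, and the helper recordDelta are all hypothetical names introduced for illustration.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical sketch of pause/resume bookkeeping; not Encog's implementation. */
final class ResumeSketch {
    // Stand-in for the LAST_DELTA resume key; the real constant's value is an assumption.
    static final String LAST_DELTA_KEY = "LAST_DELTA";

    // Previous delta for each weight, the state that must survive a pause.
    private double[] lastDelta;

    ResumeSketch(int weightCount) {
        this.lastDelta = new double[weightCount];
    }

    /** Hypothetical helper that simulates a training step writing a delta. */
    void recordDelta(int index, double delta) {
        lastDelta[index] = delta;
    }

    /** Snapshot the deltas into a map (stand-in for TrainingContinuation). */
    Map<String, Object> pause() {
        Map<String, Object> state = new HashMap<>();
        state.put(LAST_DELTA_KEY, lastDelta.clone());
        return state;
    }

    /** Assumed validity rule: the key is present and the array length matches. */
    boolean isValidResume(Map<String, Object> state) {
        Object saved = state.get(LAST_DELTA_KEY);
        return saved instanceof double[]
                && ((double[]) saved).length == lastDelta.length;
    }

    /** Restore the deltas, rejecting a state that fails the validity check. */
    void resume(Map<String, Object> state) {
        if (!isValidResume(state)) {
            throw new IllegalArgumentException("Invalid training resume data.");
        }
        lastDelta = ((double[]) state.get(LAST_DELTA_KEY)).clone();
    }

    double[] getLastDelta() {
        return lastDelta.clone();
    }

    public static void main(String[] args) {
        ResumeSketch first = new ResumeSketch(2);
        first.recordDelta(1, 0.25);
        Map<String, Object> state = first.pause();

        ResumeSketch second = new ResumeSketch(2);
        if (second.isValidResume(state)) {
            second.resume(state);
        }
        System.out.println(second.getLastDelta()[1]);
    }
}
```

Checking isValidResume() before calling resume(), as main does here, avoids restoring deltas that were saved for a network with a different number of weights.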

public void setLearningRate(double rate)
setLearningRate in interface LearningRate
rate - The learning rate.

public void setMomentum(double m)
setMomentum in interface Momentum
m - The momentum.

public double updateWeight(double[] gradients, double[] lastGradient, int index)
updateWeight in class Propagation
gradients - The gradients.
lastGradient - The last gradients.
index - The index.

public double updateWeight(double[] gradients, double[] lastGradient, int index, double dropoutRate)
updateWeight in class Propagation
gradients - The gradients.
lastGradient - The last gradients.
index - The index.
dropoutRate - The dropout rate.

public void initOthers()
initOthers in class Propagation

public boolean isNesterovUpdate()

public void setNesterovUpdate(boolean nesterovUpdate)
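
The updateWeight methods, the learning rate and the momentum documented above can be illustrated with the conventional momentum update rule. This is a minimal sketch, not Encog's source. It assumes each delta is the gradient scaled by the learning rate plus the previous delta scaled by the momentum, and that the gradients already point in the direction in which the weight should move. The class MomentumUpdateSketch is hypothetical, and the sketch leaves out the lastGradient and dropoutRate parameters and the Nesterov option.

```java
/** Hypothetical sketch of a momentum-based weight update; not Encog's implementation. */
final class MomentumUpdateSketch {
    private final double learningRate;
    private final double momentum;

    // Delta applied to each weight on the previous call.
    private final double[] lastDelta;

    MomentumUpdateSketch(int weightCount, double learningRate, double momentum) {
        this.learningRate = learningRate;
        this.momentum = momentum;
        this.lastDelta = new double[weightCount];
    }

    /**
     * Compute the amount to add to the weight at the given index and
     * remember it for the next call.
     */
    double updateWeight(double[] gradients, int index) {
        double delta = gradients[index] * learningRate
                + lastDelta[index] * momentum;
        lastDelta[index] = delta;
        return delta;
    }

    double[] getLastDelta() {
        return lastDelta.clone();
    }

    public static void main(String[] args) {
        MomentumUpdateSketch rule = new MomentumUpdateSketch(1, 0.5, 0.5);
        double[] gradients = {1.0};
        System.out.println(rule.updateWeight(gradients, 0)); // prints 0.5
        System.out.println(rule.updateWeight(gradients, 0)); // prints 0.75
    }
}
```

With a constant gradient, the second delta is larger than the first because half of the previous delta carries over. This is the influence of the previous iteration's deltas that the theMomentum parameter describes.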