public class StochasticGradientDescent extends BasicTraining implements Momentum, LearningRate
| Constructor and Description |
|---|
| `StochasticGradientDescent(ContainsFlat network, MLDataSet training)` |
| `StochasticGradientDescent(ContainsFlat network, MLDataSet training, GenerateRandom theRandom)` |
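The second constructor accepts a `GenerateRandom` source. Presumably it drives the random selection of training records for each mini-batch, although the listing does not say so. The sketch below illustrates that idea with plain `java.util.Random` in place of Encog's `GenerateRandom`. The class and method names (`BatchSampler`, `nextBatch`) are hypothetical. A fixed seed makes the sequence of sampled indices, and therefore a training run, reproducible.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical helper: draws mini-batch record indices from a seedable RNG. */
public class BatchSampler {
    private final Random random;
    private final int recordCount;

    public BatchSampler(long seed, int recordCount) {
        // Stand-in for Encog's GenerateRandom (assumption, not Encog's API).
        this.random = new Random(seed);
        this.recordCount = recordCount;
    }

    /** Returns batchSize indices drawn uniformly, with replacement, from [0, recordCount). */
    public int[] nextBatch(int batchSize) {
        int[] indices = new int[batchSize];
        for (int i = 0; i < batchSize; i++) {
            indices[i] = random.nextInt(recordCount);
        }
        return indices;
    }

    public static void main(String[] args) {
        BatchSampler a = new BatchSampler(42L, 100);
        BatchSampler b = new BatchSampler(42L, 100);
        // Same seed, same draws: both lines print the same five indices.
        System.out.println(Arrays.toString(a.nextBatch(5)));
        System.out.println(Arrays.toString(b.nextBatch(5)));
    }
}
```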
| Modifier and Type | Method and Description |
|---|---|
| `void` | `calculateRegularizationPenalty(double[] l)` |
| `boolean` | `canContinue()` |
| `int` | `getBatchSize()` |
| `FlatNetwork` | `getFlat()` |
| `double` | `getL1()` |
| `double` | `getL2()` |
| `double` | `getLearningRate()` |
| `MLMethod` | `getMethod()` Get the current best machine learning method from the training. |
| `double` | `getMomentum()` |
| `UpdateRule` | `getUpdateRule()` |
| `boolean` | `isValidResume(TrainingContinuation state)` |
| `void` | `iteration()` Perform one iteration of training. |
| `void` | `layerRegularizationPenalty(int fromLayer, double[] l)` |
| `TrainingContinuation` | `pause()` Pause the training. |
| `void` | `preIteration()` Call the strategies before an iteration. |
| `void` | `process(MLDataPair pair)` |
| `void` | `resetError()` |
| `void` | `resume(TrainingContinuation state)` Resume training. |
| `void` | `setBatchSize(int theBatchSize)` |
| `void` | `setL1(double l1)` |
| `void` | `setL2(double l2)` |
| `void` | `setLearningRate(double rate)` Set the learning rate. |
| `void` | `setMomentum(double m)` Set the momentum. |
| `void` | `setUpdateRule(UpdateRule updateRule)` |
| `void` | `update()` |
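`setLearningRate` and `setMomentum` expose the two hyperparameters of momentum SGD. One common formulation keeps a velocity per weight, v ← μ·v − η·g, and then sets w ← w + v, where η is the learning rate, μ the momentum and g the gradient. The sketch below implements that formulation. Whether Encog's default `UpdateRule` uses exactly this form is not stated in this listing, and `MomentumStep` is a hypothetical name.

```java
/** Hypothetical classic-momentum update: v = mu * v - eta * g, then w = w + v. */
public class MomentumStep {
    private double learningRate; // eta
    private double momentum;     // mu
    private final double[] velocity;

    public MomentumStep(int size, double learningRate, double momentum) {
        this.velocity = new double[size];
        this.learningRate = learningRate;
        this.momentum = momentum;
    }

    public void setLearningRate(double rate) { this.learningRate = rate; }

    public void setMomentum(double m) { this.momentum = m; }

    /** Updates the weights in place, given the gradient of the error with respect to them. */
    public void apply(double[] weights, double[] gradient) {
        for (int i = 0; i < weights.length; i++) {
            // Part of the previous step carries over; the new gradient pushes downhill.
            velocity[i] = momentum * velocity[i] - learningRate * gradient[i];
            weights[i] += velocity[i];
        }
    }
}
```

With the momentum set to 0 the rule reduces to plain gradient descent, w ← w − η·g. With a constant gradient and momentum close to 1, successive steps grow, which is why momentum speeds up progress along consistently sloped directions.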
Methods inherited from class BasicTraining: addStrategy, finishTraining, getError, getImplementationType, getIteration, getStrategies, getTraining, isTrainingDone, iteration, postIteration, setError, setIteration, setTraining

public StochasticGradientDescent(ContainsFlat network, MLDataSet training)
public StochasticGradientDescent(ContainsFlat network, MLDataSet training, GenerateRandom theRandom)
public void process(MLDataPair pair)
public void update()
public void resetError()
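The `process`, `update` and `resetError` methods suggest a split between accumulating gradients for individual training pairs and applying them to the weights once per batch. The sketch below shows that pattern for a single linear neuron trained on squared error. It is not Encog's code: `LinearMiniBatch`, its fields, and the averaging of the gradient over the batch are assumptions made for illustration.

```java
/**
 * Hypothetical mini-batch trainer for a single linear neuron, prediction = w.x + b,
 * trained on the squared error 0.5 * (prediction - ideal)^2.
 */
public class LinearMiniBatch {
    private final double[] weights;
    private final double[] gradWeights;
    private final double learningRate;
    private double bias;
    private double gradBias;
    private int pairsInBatch;
    private double errorSum;
    private int errorCount;

    public LinearMiniBatch(int inputs, double learningRate) {
        this.weights = new double[inputs];
        this.gradWeights = new double[inputs];
        this.learningRate = learningRate;
    }

    /** Accumulates the gradient and squared error for one pair; the weights are untouched. */
    public void process(double[] input, double ideal) {
        double prediction = bias;
        for (int i = 0; i < weights.length; i++) {
            prediction += weights[i] * input[i];
        }
        double delta = prediction - ideal;
        for (int i = 0; i < weights.length; i++) {
            gradWeights[i] += delta * input[i];
        }
        gradBias += delta;
        pairsInBatch++;
        errorSum += delta * delta;
        errorCount++;
    }

    /** Applies the batch-averaged gradient once, then clears the gradient accumulators. */
    public void update() {
        if (pairsInBatch == 0) {
            return;
        }
        for (int i = 0; i < weights.length; i++) {
            weights[i] -= learningRate * gradWeights[i] / pairsInBatch;
            gradWeights[i] = 0.0;
        }
        bias -= learningRate * gradBias / pairsInBatch;
        gradBias = 0.0;
        pairsInBatch = 0;
    }

    /** Clears the running error statistics without touching the weights. */
    public void resetError() {
        errorSum = 0.0;
        errorCount = 0;
    }

    /** Mean squared difference over the pairs processed since the last resetError(). */
    public double getError() {
        return errorCount == 0 ? 0.0 : errorSum / errorCount;
    }

    public double getWeight(int i) { return weights[i]; }

    public double getBias() { return bias; }

    public static void main(String[] args) {
        double[][] xs = {{0.0}, {1.0}, {2.0}, {3.0}};
        double[] ys = {1.0, 3.0, 5.0, 7.0}; // samples of y = 2x + 1
        LinearMiniBatch trainer = new LinearMiniBatch(1, 0.1);
        for (int iteration = 0; iteration < 500; iteration++) {
            trainer.resetError();
            for (int p = 0; p < xs.length; p++) {
                trainer.process(xs[p], ys[p]);
            }
            trainer.update(); // one weight update per batch of four pairs
        }
        // The weight and bias should end up close to 2 and 1.
        System.out.println(trainer.getWeight(0) + " " + trainer.getBias());
    }
}
```

Keeping `resetError` separate from `update` lets the error be reported per training iteration even when several batches are applied within it.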
public void iteration()
Specified by: iteration in interface MLTrain
public boolean canContinue()
Specified by: canContinue in interface MLTrain
public double getLearningRate()
Specified by: getLearningRate in interface LearningRate
public double getMomentum()
Specified by: getMomentum in interface Momentum
public boolean isValidResume(TrainingContinuation state)
public TrainingContinuation pause()
public void resume(TrainingContinuation state)
Specified by: resume in interface MLTrain
public MLMethod getMethod()
Specified by: getMethod in interface MLTrain
public void setLearningRate(double rate)
Specified by: setLearningRate in interface LearningRate
Parameters: rate - The new learning rate.
public void setMomentum(double m)
Specified by: setMomentum in interface Momentum
Parameters: m - The new momentum.
public void preIteration()
Overrides: preIteration in class BasicTraining
public int getBatchSize()
public void setBatchSize(int theBatchSize)
public double getL1()
public void setL1(double l1)
public double getL2()
public void setL2(double l2)
public void calculateRegularizationPenalty(double[] l)
public void layerRegularizationPenalty(int fromLayer, double[] l)
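`calculateRegularizationPenalty(double[] l)` and `layerRegularizationPenalty(int fromLayer, double[] l)` appear to fill an output array with regularization sums that are then weighted by the coefficients set through `setL1` and `setL2`; the listing does not document this. The minimal sketch below works under that assumption: `l[0]` collects Σ|w| and `l[1]` collects Σw² over every weight layer, and the total penalty is L1·Σ|w| + L2·Σw². The array layout and the `Regularization` class are hypothetical.

```java
/** Hypothetical L1/L2 penalty helper; layers[k] holds the weights leaving layer k. */
public class Regularization {
    /** Adds sum(|w|) to l[0] and sum(w * w) to l[1] for one weight layer. */
    public static void layerPenalty(double[][] layers, int fromLayer, double[] l) {
        for (double w : layers[fromLayer]) {
            l[0] += Math.abs(w);
            l[1] += w * w;
        }
    }

    /** Resets l, then accumulates the sums over every weight layer. */
    public static void totalPenalty(double[][] layers, double[] l) {
        l[0] = 0.0;
        l[1] = 0.0;
        for (int k = 0; k < layers.length; k++) {
            layerPenalty(layers, k, l);
        }
    }

    /** Combines the two sums using the L1 and L2 coefficients. */
    public static double weightedPenalty(double[] l, double l1, double l2) {
        return l1 * l[0] + l2 * l[1];
    }
}
```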
public FlatNetwork getFlat()
public UpdateRule getUpdateRule()
public void setUpdateRule(UpdateRule updateRule)
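`getUpdateRule` and `setUpdateRule` make the weight-update formula pluggable. The sketch below illustrates that strategy pattern with a hypothetical `WeightUpdate` interface. This is not Encog's `UpdateRule`, whose methods are not listed here. It has two implementations: plain SGD and the Adam optimizer with its commonly used defaults β1 = 0.9, β2 = 0.999 and ε = 1e-8. The names `AdamRule` and `PlainSgdRule` are also hypothetical.

```java
/** Hypothetical pluggable update rule, standing in for Encog's UpdateRule. */
interface WeightUpdate {
    void update(double[] weights, double[] gradient);
}

/** Plain gradient descent: w = w - eta * g. */
class PlainSgdRule implements WeightUpdate {
    private final double learningRate;

    PlainSgdRule(double learningRate) { this.learningRate = learningRate; }

    @Override
    public void update(double[] weights, double[] gradient) {
        for (int i = 0; i < weights.length; i++) {
            weights[i] -= learningRate * gradient[i];
        }
    }
}

/** Adam: steps scaled by bias-corrected first and second moment estimates of the gradient. */
public class AdamRule implements WeightUpdate {
    private final double learningRate;
    private final double beta1;
    private final double beta2;
    private final double epsilon;
    private final double[] m; // running mean of gradients
    private final double[] v; // running mean of squared gradients
    private int t;            // number of updates applied so far

    public AdamRule(int size, double learningRate) {
        this(size, learningRate, 0.9, 0.999, 1e-8);
    }

    public AdamRule(int size, double learningRate, double beta1, double beta2, double epsilon) {
        this.learningRate = learningRate;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.epsilon = epsilon;
        this.m = new double[size];
        this.v = new double[size];
    }

    @Override
    public void update(double[] weights, double[] gradient) {
        t++;
        double correction1 = 1.0 - Math.pow(beta1, t);
        double correction2 = 1.0 - Math.pow(beta2, t);
        for (int i = 0; i < weights.length; i++) {
            m[i] = beta1 * m[i] + (1.0 - beta1) * gradient[i];
            v[i] = beta2 * v[i] + (1.0 - beta2) * gradient[i] * gradient[i];
            double mHat = m[i] / correction1;
            double vHat = v[i] / correction2;
            weights[i] -= learningRate * mHat / (Math.sqrt(vHat) + epsilon);
        }
    }
}
```

A trainer that stores a `WeightUpdate` field can swap between these rules through a setter without changing its gradient code, which is the design that a `setUpdateRule` method implies. On its first step Adam moves each weight by roughly the learning rate against the sign of its gradient, regardless of the gradient's magnitude.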