public class GradientWorker extends java.lang.Object implements EngineTask

| Modifier and Type | Field and Description |
|---|---|
| protected java.util.Random | dropoutRandomSource: Used to generate randomness for dropout. |

| Constructor and Description |
|---|
| GradientWorker(FlatNetwork theNetwork, GradientWorkerOwner theOwner, MLDataSet theTraining, int theLow, int theHigh, double[] flatSpot, ErrorFunction ef): Construct a gradient worker. |

| Modifier and Type | Method and Description |
|---|---|
| void | calculateRegularizationPenalty(double[] l) |
| ErrorCalculation | getErrorCalculation() |
| double[] | getGradients() |
| FlatNetwork | getNetwork() |
| double[] | getWeights() |
| void | layerRegularizationPenalty(int fromLayer, double[] l) |
| void | process(MLDataPair pair): Process one training set element. |
| void | run(): Perform the gradient calculation for the specified index range. |
| void | run(int index) |
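
Because run() works on "the specified index range" set by theLow and theHigh, a trainer presumably splits the training set into one contiguous slice per worker. The sketch below shows one way to do that split. It assumes the high index is inclusive. The class RangeSplitSketch and its split method are hypothetical helpers, not part of the library.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: divides recordCount training elements into at most
// workerCount contiguous ranges, each returned as {low, high}. The high index
// is treated as INCLUSIVE, which is an assumption about theLow/theHigh.
class RangeSplitSketch {
    static int[][] split(int recordCount, int workerCount) {
        if (recordCount < 0 || workerCount <= 0) {
            throw new IllegalArgumentException("need recordCount >= 0 and workerCount > 0");
        }
        List<int[]> ranges = new ArrayList<>();
        int base = recordCount / workerCount;  // minimum elements per worker
        int extra = recordCount % workerCount; // the first 'extra' workers get one more
        int low = 0;
        for (int w = 0; w < workerCount; w++) {
            int size = base + (w < extra ? 1 : 0);
            if (size == 0) {
                break; // more workers than elements: do not hand out empty ranges
            }
            ranges.add(new int[] {low, low + size - 1});
            low += size;
        }
        return ranges.toArray(new int[0][]);
    }

    public static void main(String[] args) {
        for (int[] r : split(10, 3)) {
            System.out.println("worker range: " + r[0] + ".." + r[1]);
        }
    }
}
```

For example, split(10, 3) yields the ranges 0..3, 4..6 and 7..9. The ranges are non-overlapping and cover every element exactly once, so the per-worker gradients can later be summed.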

protected java.util.Random dropoutRandomSource
Used to generate randomness for dropout.
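
A java.util.Random used for dropout typically decides, unit by unit, whether an activation is kept or zeroed. The sketch below uses "inverted dropout" to illustrate that idea. This technique is named here as an assumption: the page does not say how GradientWorker consumes dropoutRandomSource. The class DropoutSketch and the dropoutRate parameter (probability of dropping a unit) are hypothetical.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical sketch of inverted dropout driven by a java.util.Random.
// Not the library's code; dropoutRate = probability that a unit is dropped.
class DropoutSketch {
    static double[] dropoutMask(Random rng, int size, double dropoutRate) {
        if (dropoutRate < 0.0 || dropoutRate >= 1.0) {
            throw new IllegalArgumentException("dropoutRate must be in [0, 1)");
        }
        double keep = 1.0 - dropoutRate;
        double[] mask = new double[size];
        for (int i = 0; i < size; i++) {
            // Keep a unit with probability 'keep'; kept units are scaled by 1/keep
            // so the expected activation matches the no-dropout case.
            mask[i] = rng.nextDouble() < keep ? 1.0 / keep : 0.0;
        }
        return mask;
    }

    public static void main(String[] args) {
        Random dropoutRandomSource = new Random(42); // fixed seed: reproducible masks
        double[] activations = {0.2, 0.9, 0.4, 0.7};
        double[] mask = dropoutMask(dropoutRandomSource, activations.length, 0.5);
        for (int i = 0; i < activations.length; i++) {
            activations[i] *= mask[i];
        }
        System.out.println(Arrays.toString(activations));
    }
}
```

Seeding the Random makes a training run repeatable, which is one plausible reason to keep the source in a field rather than creating it on each call.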

public GradientWorker(FlatNetwork theNetwork, GradientWorkerOwner theOwner, MLDataSet theTraining, int theLow, int theHigh, double[] flatSpot, ErrorFunction ef)
Construct a gradient worker.
Parameters:
- theNetwork: The network to train.
- theOwner: The owner that is doing the training.
- theTraining: The training data.
- theLow: The low index to use in the training data.
- theHigh: The high index to use in the training data.
- flatSpot: The flat-spot additions for each layer.
- ef: The error function.

public FlatNetwork getNetwork()

public double[] getWeights()
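
The flatSpot constructor argument holds a per-layer "flat-spot addition". This refers to flat-spot elimination, a technique in which a small constant is added to the activation derivative so that a saturated sigmoid unit still passes back some gradient. The sketch assumes the value is simply added to the derivative. The class FlatSpotSketch and its methods are hypothetical, and the 0.1 used below is illustrative.

```java
// Hypothetical sketch of flat-spot elimination: a constant is added to the
// sigmoid derivative so saturated units do not stop learning entirely.
// Assumes the per-layer flatSpot value is added directly to the derivative.
class FlatSpotSketch {
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Sigmoid derivative expressed through the unit's output s, plus the flat spot.
    static double derivativeWithFlatSpot(double s, double flatSpot) {
        return s * (1.0 - s) + flatSpot;
    }

    public static void main(String[] args) {
        double saturated = sigmoid(20.0); // output very close to 1.0
        System.out.println("without flat spot: " + derivativeWithFlatSpot(saturated, 0.0));
        System.out.println("with flat spot:    " + derivativeWithFlatSpot(saturated, 0.1));
    }
}
```

Without the addition, the derivative at sigmoid(20.0) is on the order of 1e-9, so the weight updates for that unit almost vanish. With it, the derivative never drops below the flat-spot value.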
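
public void process(MLDataPair pair)
Process one training set element.
Parameters:
- pair: The training data information.

public final void run()
Perform the gradient calculation for the specified index range.
Specified by: run in interface EngineTask

public final void run(int index)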
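
The division of labour between process (one element) and run() (a whole index range) can be illustrated with a toy model. The sketch below uses a single linear neuron y = w · x with squared error E = 0.5 (y - t)^2, not the library's multi-layer backpropagation. It accumulates dE/dw over an inclusive [low, high] slice. The class GradientRangeSketch and its methods are hypothetical, and the library's own gradient sign convention may differ.

```java
import java.util.Arrays;

// Hypothetical toy, not the library's backpropagation: one linear neuron
// y = w . x trained on E = 0.5 * (y - t)^2. process() adds a single element's
// contribution; accumulate() loops over an inclusive [low, high] slice.
class GradientRangeSketch {
    static void process(double[] x, double ideal, double[] weights, double[] gradients) {
        double actual = 0.0;
        for (int j = 0; j < weights.length; j++) {
            actual += weights[j] * x[j];
        }
        double delta = actual - ideal; // dE/dy
        for (int j = 0; j < weights.length; j++) {
            gradients[j] += delta * x[j]; // dE/dw_j = dE/dy * x_j
        }
    }

    static double[] accumulate(double[][] inputs, double[] ideals, double[] weights, int low, int high) {
        double[] gradients = new double[weights.length]; // starts at zero
        for (int i = low; i <= high; i++) { // inclusive high index (assumption)
            process(inputs[i], ideals[i], weights, gradients);
        }
        return gradients;
    }

    public static void main(String[] args) {
        double[][] inputs = {{1, 0}, {0, 1}, {1, 1}};
        double[] ideals = {1, 2, 3};
        double[] weights = {0, 0};
        System.out.println(Arrays.toString(accumulate(inputs, ideals, weights, 0, 2)));
    }
}
```

With all-zero weights, the three elements contribute deltas of -1, -2 and -3, so the accumulated gradient is [-4.0, -5.0]. Because each worker only reads its own slice and writes its own gradient array, several workers can run in parallel and their results can be summed afterwards.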

public ErrorCalculation getErrorCalculation()

public double[] getGradients()

public void calculateRegularizationPenalty(double[] l)

public void layerRegularizationPenalty(int fromLayer, double[] l)
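
Both regularization methods return void and take a double[] l, which suggests that they accumulate penalty terms into l layer by layer. The sketch below assumes that l[0] collects the L1 sum of |w| and l[1] collects the L2 sum of w^2. It also assumes that weights are stored as weights[layer][from][to]. Both the layout and the helper names addLayerPenalty and addTotalPenalty are hypothetical.

```java
// Hypothetical sketch: accumulate an L1 term (sum of |w|) into l[0] and an
// L2 term (sum of w^2) into l[1], one layer at a time. The l[] layout and the
// weights[layer][from][to] structure are illustrative assumptions.
class RegularizationSketch {
    static void addLayerPenalty(double[][][] weights, int fromLayer, double[] l) {
        for (double[] row : weights[fromLayer]) {
            for (double w : row) {
                l[0] += Math.abs(w); // L1 contribution
                l[1] += w * w;       // L2 contribution
            }
        }
    }

    static void addTotalPenalty(double[][][] weights, double[] l) {
        for (int layer = 0; layer < weights.length; layer++) {
            addLayerPenalty(weights, layer, l);
        }
    }

    public static void main(String[] args) {
        double[][][] weights = {{{1.0, -2.0}}, {{3.0}}};
        double[] l = new double[2];
        addTotalPenalty(weights, l);
        System.out.println("L1 = " + l[0] + ", L2 = " + l[1]);
    }
}
```

This prints "L1 = 6.0, L2 = 14.0". A trainer could then scale the two sums by its own regularization coefficients before adding them to the error. How the library combines them is not stated here.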