java.lang.Object
  org.encog.engine.network.train.prop.TrainFlatNetworkProp
public abstract class TrainFlatNetworkProp
Trains a flat network using multithreading and GPU support. The training data must be indexable; it is broken into groups for each thread to process. At the end of each iteration, the training results from each thread are aggregated back into the neural network.
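The splitting step described above can be sketched as follows. This is only an illustration of dividing an indexable data set into one contiguous index range per worker; `RangeSplitter` and `split` are hypothetical names that do not appear in Encog, and the library's actual partitioning strategy may differ.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper (not part of Encog): one contiguous index range per worker. */
final class RangeSplitter {

    /** Returns {low, high} pairs, both inclusive, that together cover 0..recordCount-1. */
    static List<int[]> split(int recordCount, int workers) {
        if (recordCount <= 0 || workers <= 0) {
            throw new IllegalArgumentException("recordCount and workers must be positive");
        }
        // Never create more groups than there are records.
        int groups = Math.min(workers, recordCount);
        int base = recordCount / groups;   // minimum size of every group
        int extra = recordCount % groups;  // the first 'extra' groups get one more record
        List<int[]> ranges = new ArrayList<>();
        int low = 0;
        for (int i = 0; i < groups; i++) {
            int size = base + (i < extra ? 1 : 0);
            ranges.add(new int[] {low, low + size - 1});
            low += size;
        }
        return ranges;
    }

    public static void main(String[] args) {
        for (int[] range : split(10, 3)) {
            System.out.println(range[0] + ".." + range[1]);
        }
    }
}
```

Contiguous ranges only work when records can be fetched by index, which is presumably why the training data has to be indexable.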
| Field Summary | |
|---|---|
| protected double | currentError: The current error is the average error over all of the threads. |
| protected double[] | gradients: The gradients. |
| protected FlatNetwork | network: The network to train. |
| Constructor Summary | |
|---|---|
| TrainFlatNetworkProp(FlatNetwork network, EngineDataSet training): Train a flat network multithreaded. | |
| Method Summary | |
|---|---|
| void | calculateGradients(): Calculate the gradients. |
| void | finishTraining(): Training is to stop, free any resources. |
| double | getError() |
| int | getIteration() |
| double[] | getLastGradient() |
| FlatNetwork | getNetwork() |
| int | getNumThreads() |
| EngineDataSet | getTraining() |
| void | iteration(): Perform one training iteration. |
| void | iteration(int count): Perform the specified number of training iterations. |
| protected void | learn(): Apply and learn. |
| protected void | learnLimited(): Apply and learn. |
| void | report(double[] gradients, double error, Throwable ex): Called by the worker threads to report the progress at each step. |
| void | setIteration(int iteration): Set the iteration. |
| void | setNumThreads(int numThreads): Set the number of threads to use. |
| abstract double | updateWeight(double[] gradients, double[] lastGradient, int index): Update a weight; the means by which weights are updated varies depending on the training. |
| Methods inherited from class java.lang.Object |
|---|
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait |
| Field Detail |
|---|
protected double[] gradients
protected final FlatNetwork network
protected double currentError
| Constructor Detail |
|---|
public TrainFlatNetworkProp(FlatNetwork network,
EngineDataSet training)
Parameters:
network - The network to train.
training - The training data to use.

| Method Detail |
|---|
public void calculateGradients()

public void finishTraining()
Specified by: finishTraining in interface TrainFlatNetwork

public double getError()
Specified by: getError in interface TrainFlatNetwork

public double[] getLastGradient()

public FlatNetwork getNetwork()
Specified by: getNetwork in interface TrainFlatNetwork

public int getNumThreads()
Specified by: getNumThreads in interface TrainFlatNetwork

public EngineDataSet getTraining()
Specified by: getTraining in interface TrainFlatNetwork

public void iteration()
Specified by: iteration in interface TrainFlatNetwork

protected void learn()

protected void learnLimited()

public void report(double[] gradients,
double error,
Throwable ex)
Parameters:
gradients - The gradients from that worker.
error - The error for that worker.
ex - The exception.

public void setNumThreads(int numThreads)
Specified by: setNumThreads in interface TrainFlatNetwork
Parameters:
numThreads - The number of threads to use.
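To illustrate how per-worker reports could be combined, here is a minimal sketch of a thread-safe aggregator. It assumes that gradients are summed across workers and that the error is averaged over the reporting workers (matching the description of `currentError`). `GradientAggregator` is a hypothetical class, not Encog's implementation, and the way it records a reported exception is an assumption.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/** Hypothetical aggregator (not Encog's implementation). */
final class GradientAggregator {
    private final double[] gradients;
    private double totalError;
    private int reports;
    private Throwable firstFailure;

    GradientAggregator(int weightCount) {
        this.gradients = new double[weightCount];
    }

    /** Called by each worker; synchronized so concurrent reports cannot interleave. */
    synchronized void report(double[] workerGradients, double error, Throwable ex) {
        if (ex != null) {
            // Assumption: remember the first failure and ignore that worker's numbers.
            if (firstFailure == null) {
                firstFailure = ex;
            }
            return;
        }
        for (int i = 0; i < gradients.length; i++) {
            gradients[i] += workerGradients[i];
        }
        totalError += error;
        reports++;
    }

    /** Average error over all workers that reported successfully. */
    synchronized double averageError() {
        return reports == 0 ? 0.0 : totalError / reports;
    }

    /** Defensive copy of the summed gradients. */
    synchronized double[] gradients() {
        return gradients.clone();
    }

    synchronized Throwable failure() {
        return firstFailure;
    }

    public static void main(String[] args) throws InterruptedException {
        GradientAggregator aggregator = new GradientAggregator(2);
        ExecutorService pool = Executors.newFixedThreadPool(4);
        for (int t = 0; t < 4; t++) {
            pool.execute(() -> aggregator.report(new double[] {1.0, 2.0}, 0.5, null));
        }
        pool.shutdown();
        pool.awaitTermination(10, TimeUnit.SECONDS);
        System.out.println("average error = " + aggregator.averageError());
    }
}
```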

public abstract double updateWeight(double[] gradients,
double[] lastGradient,
int index)
Parameters:
gradients - The gradients.
lastGradient - The last gradients.
index - The index.
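Because the update rule is left to subclasses, an example of one possible rule may help. The sketch below uses a learning rate with momentum and keeps the same parameter list as `updateWeight`. It is a hypothetical stand-in: `MomentumUpdateSketch` does not extend `TrainFlatNetworkProp`, and two things are assumptions not stated on this page, namely that the return value is the amount to add to the weight and that the gradient already points in the direction of improvement.

```java
/** Hypothetical stand-in; a real implementation would extend TrainFlatNetworkProp. */
final class MomentumUpdateSketch {
    private final double learningRate;
    private final double momentum;
    private final double[] lastDelta; // previous change applied to each weight

    MomentumUpdateSketch(int weightCount, double learningRate, double momentum) {
        this.learningRate = learningRate;
        this.momentum = momentum;
        this.lastDelta = new double[weightCount];
    }

    /**
     * Assumption: returns the change to apply to the weight at 'index'.
     * This particular rule does not need lastGradient; rules that compare
     * gradient signs between iterations would use it.
     */
    double updateWeight(double[] gradients, double[] lastGradient, int index) {
        double delta = learningRate * gradients[index] + momentum * lastDelta[index];
        lastDelta[index] = delta;
        return delta;
    }

    public static void main(String[] args) {
        MomentumUpdateSketch rule = new MomentumUpdateSketch(1, 0.5, 0.5);
        double[] gradients = {2.0};
        double[] lastGradient = {0.0};
        System.out.println(rule.updateWeight(gradients, lastGradient, 0));
        System.out.println(rule.updateWeight(gradients, lastGradient, 0));
    }
}
```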

public void iteration(int count)
Specified by: iteration in interface TrainFlatNetwork
Parameters:
count - The number of training iterations.

public int getIteration()
Specified by: getIteration in interface TrainFlatNetwork

public void setIteration(int iteration)
Specified by: setIteration in interface TrainFlatNetwork
Parameters:
iteration - The iteration.
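A typical way to drive the methods above is a loop that iterates until the error is low enough and then releases resources. The sketch below shows that pattern against a small stand-in interface so that it runs without Encog. `Trainer`, `HalvingTrainer`, and `trainUntil` are hypothetical names, and the assumption that the iteration counter advances by one per call to `iteration()` is not confirmed by this page.

```java
final class TrainingLoopSketch {

    /** Stand-in mirroring four of the methods documented above. */
    interface Trainer {
        void iteration();
        double getError();
        int getIteration();
        void finishTraining();
    }

    /** Trains until the error reaches the target or the iteration cap is hit. */
    static int trainUntil(Trainer trainer, double targetError, int maxIterations) {
        try {
            do {
                trainer.iteration();
            } while (trainer.getError() > targetError
                    && trainer.getIteration() < maxIterations);
        } finally {
            // Always free resources, even if an iteration throws.
            trainer.finishTraining();
        }
        return trainer.getIteration();
    }

    /** Fake trainer whose error halves on every iteration; for demonstration only. */
    static final class HalvingTrainer implements Trainer {
        double error = 1.0;
        int iteration;
        boolean finished;

        public void iteration() {
            error /= 2.0;
            iteration++;
        }

        public double getError() {
            return error;
        }

        public int getIteration() {
            return iteration;
        }

        public void finishTraining() {
            finished = true;
        }
    }

    public static void main(String[] args) {
        HalvingTrainer trainer = new HalvingTrainer();
        int used = trainUntil(trainer, 0.1, 100);
        System.out.println("iterations = " + used + ", error = " + trainer.getError());
    }
}
```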