org.encog.neural.flat.train.prop
Class TrainFlatNetworkProp

java.lang.Object
  extended by org.encog.neural.flat.train.prop.TrainFlatNetworkProp
All Implemented Interfaces:
TrainFlatNetwork
Direct Known Subclasses:
TrainFlatNetworkBackPropagation, TrainFlatNetworkManhattan, TrainFlatNetworkQPROP, TrainFlatNetworkResilient, TrainFlatNetworkSCG

public abstract class TrainFlatNetworkProp
extends Object
implements TrainFlatNetwork

Train a flat network using multithreading and GPU support. The training data must be indexable; it is broken into groups, one for each thread to process. At the end of each iteration, the training results from each thread are aggregated back into the neural network.
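The splitting step can be sketched as follows. This is a minimal illustration, not Encog's code: the class `WorkloadSplitter` and its `partition` method are hypothetical. The sketch assumes only that the training set reports its record count and can be addressed by index, which is what allows each thread to receive a contiguous range of records.

```java
// Hypothetical helper (not part of Encog): splits N indexable training
// records into contiguous, inclusive [low, high] index ranges, one per thread.
final class WorkloadSplitter {
    private WorkloadSplitter() {}

    static int[][] partition(int recordCount, int threadCount) {
        if (recordCount < 0 || threadCount < 1) {
            throw new IllegalArgumentException("recordCount >= 0 and threadCount >= 1 required");
        }
        // Never create more groups than records, but always at least one group.
        int groups = Math.max(1, Math.min(threadCount, recordCount));
        int base = recordCount / groups;      // minimum records per group
        int remainder = recordCount % groups; // the first 'remainder' groups get one extra
        int[][] ranges = new int[groups][2];
        int low = 0;
        for (int g = 0; g < groups; g++) {
            int size = base + (g < remainder ? 1 : 0);
            ranges[g][0] = low;            // inclusive start index
            ranges[g][1] = low + size - 1; // inclusive end index
            low += size;
        }
        return ranges;
    }
}
```

For example, `partition(10, 3)` returns the ranges `[0, 3]`, `[4, 6]`, and `[7, 9]`; the first group absorbs the leftover record.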


Field Summary
protected  double currentError
          The current error is the average error over all of the threads.
protected  double[] gradients
          The gradients.
protected  double lastError
          The last error.
protected  FlatNetwork network
          The network to train.
 
Constructor Summary
TrainFlatNetworkProp(FlatNetwork network, MLDataSet training)
          Construct a multithreaded trainer for a flat network.
 
Method Summary
 void calculateGradients()
          Calculate the gradients.
 void finishTraining()
          Called when training is to stop; frees any resources.
 void fixFlatSpot(boolean e)
           
 double getError()
          
 ErrorFunction getErrorFunction()
           
 int getIteration()
          
 double[] getLastGradient()
           
 FlatNetwork getNetwork()
          
 int getNumThreads()
          
 MLDataSet getTraining()
          
abstract  void initOthers()
           
 void iteration()
          Perform one training iteration.
 void iteration(int count)
          Perform the specified number of training iterations.
protected  void learn()
          Apply and learn.
protected  void learnLimited()
          Apply and learn, zeroing any weights below the limit threshold.
 void report(double[] gradients, double error, Throwable ex)
          Called by the worker threads to report the progress at each step.
 void setErrorFunction(ErrorFunction ef)
          Set the error function.
 void setIteration(int iteration)
          Set the iteration.
 void setNumThreads(int numThreads)
          Set the number of threads to use.
abstract  double updateWeight(double[] gradients, double[] lastGradient, int index)
          Update a weight. The means by which weights are updated varies with the training method.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

gradients

protected double[] gradients
The gradients.


network

protected final FlatNetwork network
The network to train.


currentError

protected double currentError
The current error is the average error over all of the threads.


lastError

protected double lastError
The last error.

Constructor Detail

TrainFlatNetworkProp

public TrainFlatNetworkProp(FlatNetwork network,
                            MLDataSet training)
Construct a multithreaded trainer for a flat network.

Parameters:
network - The network to train.
training - The training data to use.
Method Detail

calculateGradients

public void calculateGradients()
Calculate the gradients.
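As a rough illustration of what calculating the gradients involves, here is a sketch for a single linear output layer with no hidden layers and no activation function. The class `LinearGradient`, the row-major weight layout (`weights[o * inputCount + i]`), and the sign convention (error taken as ideal minus actual, so gradients are later added to the weights) are all assumptions for this example; flat networks in general have several layers and activation functions.

```java
// Hypothetical sketch (not Encog's code): accumulate gradients over a data set
// for a single-layer linear network where
// output[o] = sum_i weights[o * inputCount + i] * input[i].
final class LinearGradient {
    private LinearGradient() {}

    static double[] calculateGradients(double[] weights, double[][] inputs,
                                       double[][] ideals, int inputCount) {
        int outputCount = weights.length / inputCount;
        double[] gradients = new double[weights.length];
        for (int r = 0; r < inputs.length; r++) {
            for (int o = 0; o < outputCount; o++) {
                // Forward pass for output neuron o on record r.
                double actual = 0.0;
                for (int i = 0; i < inputCount; i++) {
                    actual += weights[o * inputCount + i] * inputs[r][i];
                }
                // Error as (ideal - actual); a linear output has derivative 1.
                double delta = ideals[r][o] - actual;
                // Accumulate delta * input for every weight feeding neuron o.
                for (int i = 0; i < inputCount; i++) {
                    gradients[o * inputCount + i] += delta * inputs[r][i];
                }
            }
        }
        return gradients;
    }
}
```

With this sign convention, the accumulated value equals the negative derivative of half the summed squared error, so adding a positive multiple of it to a weight moves the weight downhill.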


finishTraining

public void finishTraining()
Called when training is to stop; frees any resources.

Specified by:
finishTraining in interface TrainFlatNetwork

getError

public final double getError()

Specified by:
getError in interface TrainFlatNetwork
Returns:
The error from the neural network.

getIteration

public final int getIteration()

Specified by:
getIteration in interface TrainFlatNetwork
Returns:
The current iteration.

getLastGradient

public final double[] getLastGradient()
Returns:
The gradients from the last iteration.

getNetwork

public final FlatNetwork getNetwork()

Specified by:
getNetwork in interface TrainFlatNetwork
Returns:
The trained neural network.

getNumThreads

public final int getNumThreads()

Specified by:
getNumThreads in interface TrainFlatNetwork
Returns:
The number of threads.

getTraining

public final MLDataSet getTraining()

Specified by:
getTraining in interface TrainFlatNetwork
Returns:
The data we are training with.

fixFlatSpot

public void fixFlatSpot(boolean e)

iteration

public void iteration()
Perform one training iteration.

Specified by:
iteration in interface TrainFlatNetwork

iteration

public final void iteration(int count)
Perform the specified number of training iterations. This is a basic implementation that simply calls iteration() the specified number of times. However, some training methods, particularly those that use the GPU, benefit greatly from being called with a count greater than 1.

Specified by:
iteration in interface TrainFlatNetwork
Parameters:
count - The number of training iterations.
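The basic behavior described above amounts to a simple loop. The sketch below uses a hypothetical `Trainer` base class with a hypothetical `step()` hook standing in for the concrete training step; it is not Encog's implementation. As in the documented signature, the batched method is `final`, and this sketch chooses to count iterations inside the single-step `iteration()`.

```java
// Hypothetical stand-in for a trainer (not Encog's class).
abstract class Trainer {
    private int iterationNumber;

    // One training step, supplied by the concrete trainer.
    protected abstract void step();

    // Perform one training iteration and count it (a choice of this sketch).
    public void iteration() {
        step();
        iterationNumber++;
    }

    // Basic implementation: call iteration() 'count' times.
    public final void iteration(int count) {
        for (int i = 0; i < count; i++) {
            iteration();
        }
    }

    public int getIteration() {
        return iterationNumber;
    }
}
```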

learn

protected void learn()
Apply and learn.


learnLimited

protected void learnLimited()
Apply and learn. This is the same as learn, but it checks whether any of the weights are below the limit threshold; any such weights are zeroed out. Having two methods allows the regular learn method, which is the one usually used, to be as fast as possible.
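The difference between the two methods can be sketched as below. `WeightApplier`, the `WeightUpdate` interface (standing in for a subclass's `updateWeight`), the `limit` parameter, and the use of a plain `double[]` for the weights are all hypothetical. The sketch also assumes that "below the limit threshold" means the weight's absolute value is smaller than the limit.

```java
// Hypothetical per-weight update rule, standing in for updateWeight(...).
interface WeightUpdate {
    double update(double[] gradients, double[] lastGradient, int index);
}

// Hypothetical helper (not Encog's code) contrasting learn() and learnLimited().
final class WeightApplier {
    private WeightApplier() {}

    // Fast path: apply the update to every weight with no extra checks.
    static void learn(double[] weights, double[] gradients,
                      double[] lastGradient, WeightUpdate rule) {
        for (int i = 0; i < weights.length; i++) {
            weights[i] += rule.update(gradients, lastGradient, i);
        }
    }

    // Same as learn(), but weights whose magnitude is below 'limit' are zeroed
    // instead of updated. The extra branch is why it is a separate method.
    static void learnLimited(double[] weights, double[] gradients,
                             double[] lastGradient, WeightUpdate rule, double limit) {
        for (int i = 0; i < weights.length; i++) {
            if (Math.abs(weights[i]) < limit) {
                weights[i] = 0.0;
            } else {
                weights[i] += rule.update(gradients, lastGradient, i);
            }
        }
    }
}
```

Keeping the threshold test out of `learn` means the common case pays nothing for a feature most training runs do not use.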


report

public final void report(double[] gradients,
                         double error,
                         Throwable ex)
Called by the worker threads to report the progress at each step.

Parameters:
gradients - The gradients from that worker.
error - The error for that worker.
ex - The exception raised by that worker, if any.
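One way the worker reports could be aggregated is sketched below. `GradientAggregator` and its methods are hypothetical, not Encog's implementation. The sketch sums each worker's gradients into a shared array, averages the workers' errors to produce the current error (matching the description of `currentError` as the average over all threads), and keeps the first reported exception so the training thread can rethrow it. Treating a non-null exception as a failed report, rather than also merging its gradients, is an assumption of this sketch.

```java
// Hypothetical aggregator for results reported by several worker threads.
final class GradientAggregator {
    private final double[] gradients;
    private double errorSum;
    private int reports;
    private Throwable firstFailure;

    GradientAggregator(int weightCount) {
        this.gradients = new double[weightCount];
    }

    // Called by each worker; synchronized so concurrent reports do not race.
    synchronized void report(double[] workerGradients, double error, Throwable ex) {
        if (ex != null) {
            // Keep only the first failure for the training thread to rethrow.
            if (firstFailure == null) {
                firstFailure = ex;
            }
            return;
        }
        for (int i = 0; i < gradients.length; i++) {
            gradients[i] += workerGradients[i];
        }
        errorSum += error;
        reports++;
    }

    // Average error over all workers that reported successfully.
    synchronized double currentError() {
        return reports == 0 ? 0.0 : errorSum / reports;
    }

    // Defensive copy of the summed gradients.
    synchronized double[] gradients() {
        return gradients.clone();
    }

    synchronized Throwable failure() {
        return firstFailure;
    }
}
```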

setIteration

public void setIteration(int iteration)
Set the iteration.

Specified by:
setIteration in interface TrainFlatNetwork
Parameters:
iteration - The iteration.

setNumThreads

public void setNumThreads(int numThreads)
Set the number of threads to use.

Specified by:
setNumThreads in interface TrainFlatNetwork
Parameters:
numThreads - The number of threads to use.

updateWeight

public abstract double updateWeight(double[] gradients,
                                    double[] lastGradient,
                                    int index)
Update a weight. The means by which weights are updated varies with the training method.

Parameters:
gradients - The gradients.
lastGradient - The last gradients.
index - The index.
Returns:
The update value.
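Two simple update rules illustrate how subclasses can differ. These are sketches under stated assumptions, not the code of the listed subclasses: `PropagationRule` is a hypothetical stand-in for the abstract class, the backpropagation variant uses the textbook learning-rate-plus-momentum rule, and the Manhattan variant uses only the sign of the gradient. Gradients are assumed to already point in the error-reducing direction, so the returned value is added to the weight.

```java
// Hypothetical stand-in for the abstract updateWeight hook.
abstract class PropagationRule {
    abstract double updateWeight(double[] gradients, double[] lastGradient, int index);
}

// Backpropagation-style rule: learning rate times gradient, plus momentum
// times the previous change to the same weight.
final class BackpropRule extends PropagationRule {
    private final double learningRate;
    private final double momentum;
    private final double[] lastDelta;

    BackpropRule(int weightCount, double learningRate, double momentum) {
        this.learningRate = learningRate;
        this.momentum = momentum;
        this.lastDelta = new double[weightCount];
    }

    @Override
    double updateWeight(double[] gradients, double[] lastGradient, int index) {
        double delta = gradients[index] * learningRate + lastDelta[index] * momentum;
        lastDelta[index] = delta; // remembered for the momentum term next time
        return delta;
    }
}

// Manhattan-style rule: a fixed step in the direction of the gradient's sign.
final class ManhattanRule extends PropagationRule {
    private final double learningRate;
    private final double zeroTolerance;

    ManhattanRule(double learningRate, double zeroTolerance) {
        this.learningRate = learningRate;
        this.zeroTolerance = zeroTolerance;
    }

    @Override
    double updateWeight(double[] gradients, double[] lastGradient, int index) {
        double g = gradients[index];
        if (Math.abs(g) < zeroTolerance) {
            return 0.0; // treat tiny gradients as zero
        }
        return g > 0 ? learningRate : -learningRate;
    }
}
```

Neither rule reads `lastGradient`; rules that compare the signs of successive gradients, such as resilient propagation, are the ones that need it.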

setErrorFunction

public void setErrorFunction(ErrorFunction ef)
Set the error function.

Parameters:
ef - The error function.
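A pluggable error function can be sketched as a small strategy interface. `ErrorFn`, `LinearErrorFn`, and `AtanErrorFn` are hypothetical names that only illustrate the idea; they are not the `ErrorFunction` API. Each implementation fills an array with one error term per output from the ideal and actual values, which a gradient calculation would then consume.

```java
// Hypothetical strategy interface: fills 'error' with one term per output.
interface ErrorFn {
    void calculateError(double[] ideal, double[] actual, double[] error);
}

// Linear error: the plain difference between ideal and actual output.
final class LinearErrorFn implements ErrorFn {
    @Override
    public void calculateError(double[] ideal, double[] actual, double[] error) {
        for (int i = 0; i < actual.length; i++) {
            error[i] = ideal[i] - actual[i];
        }
    }
}

// Illustrative bounded alternative: the arctangent of the difference,
// which grows more slowly than the difference itself for large errors.
final class AtanErrorFn implements ErrorFn {
    @Override
    public void calculateError(double[] ideal, double[] actual, double[] error) {
        for (int i = 0; i < actual.length; i++) {
            error[i] = Math.atan(ideal[i] - actual[i]);
        }
    }
}
```

Because the trainer only depends on the interface, swapping error functions changes the error terms without touching the gradient or weight-update code.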

getErrorFunction

public ErrorFunction getErrorFunction()
Returns:
The error function.

initOthers

public abstract void initOthers()


Copyright © 2011. All Rights Reserved.