org.encog.engine.network.train.prop
Class TrainFlatNetworkOpenCL

java.lang.Object
  extended by org.encog.engine.network.train.prop.TrainFlatNetworkOpenCL
All Implemented Interfaces:
TrainFlatNetwork

public class TrainFlatNetworkOpenCL
extends Object
implements TrainFlatNetwork

Train a flat network using OpenCL.


Field Summary
static int LEARN_BPROP
          Learn backpropagation.
static int LEARN_MANHATTAN
          Learn Manhattan update rule.
static int LEARN_RPROP
          Learn RPROP.
 
Constructor Summary
TrainFlatNetworkOpenCL(FlatNetwork network, EngineDataSet training, OpenCLTrainingProfile profile)
          Construct a trainer that trains a flat network using OpenCL.
 
Method Summary
 void finishTraining()
          Stop training and free any resources.
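 double getError()
          Get the error from the neural network.
 int getIteration()
          Get the current iteration.
 double[] getLastGradient()
          Get the last gradients.
 double getLearningRate()
          Get the learning rate.
 int getLearningType()
          Get the learning type.
 double getMaxStep()
          Get the max step.
 double getMomentum()
          Get the momentum.
 FlatNetwork getNetwork()
          Get the trained neural network.
 int getNumThreads()
          Get the number of threads.
 EngineDataSet getTraining()
          Get the training data to use.
 double[] getUpdateValues()
          Get the update values.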
 void iteration()
          Perform one training iteration.
 void iteration(int iterations)
          Perform the specified number of training iterations.
 void learnBPROP(double learningRate, double momentum)
          Learn using backpropagation.
 void learnManhattan(double learningRate)
          Learn using the Manhattan update rule.
 void learnRPROP()
          Learn using RPROP.
 void learnRPROP(double initialUpdate, double maxStep)
          Learn using RPROP with a custom initial update and max step.
 void setIteration(int iteration)
          Set the iteration.
 void setNumThreads(int numThreads)
          Set the number of threads to use.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

LEARN_RPROP

public static final int LEARN_RPROP
Learn RPROP.

See Also:
Constant Field Values

LEARN_BPROP

public static final int LEARN_BPROP
Learn backpropagation.

See Also:
Constant Field Values

LEARN_MANHATTAN

public static final int LEARN_MANHATTAN
Learn Manhattan update rule.

See Also:
Constant Field Values
Constructor Detail

TrainFlatNetworkOpenCL

public TrainFlatNetworkOpenCL(FlatNetwork network,
                              EngineDataSet training,
                              OpenCLTrainingProfile profile)
Construct a trainer that trains a flat network using OpenCL.

Parameters:
network - The network to train.
training - The training data to use.
profile - The OpenCL training profile.
Method Detail

finishTraining

public void finishTraining()
Stop training and free any resources.

Specified by:
finishTraining in interface TrainFlatNetwork
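A typical way to drive a trainer is to call iteration() until the error is acceptable (or an iteration budget runs out) and then call finishTraining() so that any resources are freed, even if training fails part-way. The sketch below shows that pattern under stated assumptions: FlatTrainer is a hypothetical interface mirroring only the iteration(), getError() and finishTraining() methods documented here, and ToyTrainer is a hypothetical stand-in that minimizes a one-variable quadratic instead of training a network, so the loop runs without Encog or an OpenCL device.

```java
public class TrainingLoopSketch {

    /** Hypothetical subset of the TrainFlatNetwork methods used by the loop. */
    interface FlatTrainer {
        void iteration();
        double getError();
        void finishTraining();
    }

    /** Hypothetical stand-in: gradient descent on E(w) = (w - 3)^2 instead of a network. */
    static class ToyTrainer implements FlatTrainer {
        double w = 0.0;
        boolean finished = false;

        public void iteration() {
            w -= 0.1 * 2.0 * (w - 3.0); // step against dE/dw = 2(w - 3)
        }

        public double getError() {
            return (w - 3.0) * (w - 3.0);
        }

        public void finishTraining() {
            finished = true; // a real trainer would free its resources here
        }
    }

    /** Iterates until the error reaches targetError or maxIterations passes have run. */
    static int train(FlatTrainer trainer, double targetError, int maxIterations) {
        int epoch = 0;
        try {
            do {
                trainer.iteration();
                epoch++;
            } while (trainer.getError() > targetError && epoch < maxIterations);
        } finally {
            trainer.finishTraining(); // runs even if an iteration throws
        }
        return epoch;
    }

    public static void main(String[] args) {
        ToyTrainer trainer = new ToyTrainer();
        int epochs = train(trainer, 1e-6, 1000);
        System.out.println("epochs=" + epochs + ", error=" + trainer.getError());
    }
}
```

Putting finishTraining() in a finally block keeps the cleanup on every exit path, which matters most for a trainer that holds external resources.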

getError

public double getError()

Specified by:
getError in interface TrainFlatNetwork
Returns:
The error from the neural network.

getIteration

public int getIteration()

Specified by:
getIteration in interface TrainFlatNetwork
Returns:
The current iteration.

getLastGradient

public double[] getLastGradient()
Returns:
The last gradients.

getLearningRate

public double getLearningRate()
Returns:
The learning rate.

getLearningType

public int getLearningType()
Returns:
The learning type.

getMaxStep

public double getMaxStep()
Returns:
The max step.

getMomentum

public double getMomentum()
Returns:
The momentum.

getNetwork

public FlatNetwork getNetwork()

Specified by:
getNetwork in interface TrainFlatNetwork
Returns:
The trained neural network.

getNumThreads

public int getNumThreads()

Specified by:
getNumThreads in interface TrainFlatNetwork
Returns:
The number of threads.

getTraining

public EngineDataSet getTraining()
Specified by:
getTraining in interface TrainFlatNetwork
Returns:
The training data to use.

getUpdateValues

public double[] getUpdateValues()
Returns:
The update values.

iteration

public void iteration()
Perform one training iteration.

Specified by:
iteration in interface TrainFlatNetwork

iteration

public void iteration(int iterations)
Perform the specified number of training iterations.

Specified by:
iteration in interface TrainFlatNetwork
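The exact relationship between iteration() and iteration(int iterations) is not detailed on this page. Going by the parameter name, the following sketch assumes that iteration(int) runs the requested number of training passes and advances the counter reported by getIteration(). IterationSketch is a hypothetical class, not the OpenCL trainer, and its "training pass" merely halves a placeholder error value.

```java
public class IterationSketch {
    private int iteration;      // value reported by getIteration()
    private double error = 1.0; // placeholder "network error"

    /** One training pass; here it only halves the placeholder error. */
    public void iteration() {
        error *= 0.5;
        iteration++;
    }

    /** Assumed semantics: run the requested number of passes. */
    public void iteration(int iterations) {
        if (iterations < 1) {
            throw new IllegalArgumentException("iterations must be at least 1");
        }
        for (int i = 0; i < iterations; i++) {
            iteration();
        }
    }

    public int getIteration() { return iteration; }
    public void setIteration(int iteration) { this.iteration = iteration; }
    public double getError() { return error; }

    public static void main(String[] args) {
        IterationSketch trainer = new IterationSketch();
        trainer.iteration(3);
        System.out.println(trainer.getIteration() + " " + trainer.getError());
    }
}
```

Under this reading, setIteration() is mainly useful for resuming a counter, for example after restoring a trainer from a checkpoint.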

learnBPROP

public void learnBPROP(double learningRate,
                       double momentum)
Learn using backpropagation.

Parameters:
learningRate - The learning rate.
momentum - The momentum.
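Backpropagation with momentum is commonly written as Δw(t) = -η·∂E/∂w + α·Δw(t-1), where η is the learning rate and α the momentum. The following is a minimal plain-Java sketch of that rule, not the OpenCL kernel this class runs. It assumes each gradient holds ∂E/∂w (some implementations store the negated gradient and add it instead), and the names BackpropMomentumSketch and update are hypothetical.

```java
import java.util.Arrays;

public class BackpropMomentumSketch {

    /**
     * One backpropagation step with momentum, applied in place.
     * gradients[i] is assumed to be dE/dw_i; lastDelta[i] holds the
     * previous weight change and is overwritten with the new one.
     */
    static void update(double[] weights, double[] gradients, double[] lastDelta,
                       double learningRate, double momentum) {
        for (int i = 0; i < weights.length; i++) {
            double delta = -learningRate * gradients[i] + momentum * lastDelta[i];
            weights[i] += delta;
            lastDelta[i] = delta;
        }
    }

    public static void main(String[] args) {
        double[] weights = {0.5, -0.2};
        double[] gradients = {0.1, -0.4};
        double[] lastDelta = new double[weights.length]; // no previous change yet
        update(weights, gradients, lastDelta, 0.7, 0.3);
        System.out.println(Arrays.toString(weights));
    }
}
```

The momentum term carries part of the previous change forward, which smooths the trajectory when successive gradients point in similar directions.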

learnManhattan

public void learnManhattan(double learningRate)
Learn using the Manhattan update rule.

Parameters:
learningRate - The learning rate.
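The Manhattan update rule ignores the magnitude of each gradient and moves every weight by a fixed amount, the learning rate, in the direction that reduces the error. A minimal sketch, assuming each gradient holds ∂E/∂w; ManhattanSketch is a hypothetical name, and a real implementation may additionally treat gradients below some small tolerance as zero.

```java
import java.util.Arrays;

public class ManhattanSketch {

    /** Moves each weight by learningRate against the sign of its gradient (assumed dE/dw). */
    static void update(double[] weights, double[] gradients, double learningRate) {
        for (int i = 0; i < weights.length; i++) {
            // Math.signum yields -1.0, 0.0 or 1.0, so a zero gradient leaves the weight unchanged.
            weights[i] -= learningRate * Math.signum(gradients[i]);
        }
    }

    public static void main(String[] args) {
        double[] weights = {1.0, 1.0, 1.0};
        double[] gradients = {0.25, -3.0, 0.0};
        update(weights, gradients, 0.01);
        System.out.println(Arrays.toString(weights));
    }
}
```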

learnRPROP

public void learnRPROP()
Learn using RPROP with the default initial update and max step.


learnRPROP

public void learnRPROP(double initialUpdate,
                       double maxStep)
Learn using RPROP with a custom initial update and max step.

Parameters:
initialUpdate - The initial update value.
maxStep - The max step.
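RPROP keeps a separate step size (update value) for each weight together with the previous gradient, which is presumably what getUpdateValues() and getLastGradient() expose. Each step size starts at initialUpdate, grows while the gradient keeps its sign, shrinks when the sign flips, and never exceeds maxStep. The sketch below implements the iRPROP- variant in plain Java; it is not necessarily the variant this class uses. The constants ETA_PLUS = 1.2, ETA_MINUS = 0.5 and DELTA_MIN = 1e-6 are values commonly used in the RPROP literature, not values read from this class, and RpropSketch is a hypothetical name.

```java
import java.util.Arrays;

public class RpropSketch {
    static final double ETA_PLUS = 1.2;   // growth factor while the gradient keeps its sign
    static final double ETA_MINUS = 0.5;  // shrink factor after a sign change
    static final double DELTA_MIN = 1e-6; // lower bound on any step size

    final double[] updateValues; // per-weight step sizes
    final double[] lastGradient; // gradients from the previous step
    final double maxStep;        // upper bound on any step size

    RpropSketch(int weightCount, double initialUpdate, double maxStep) {
        this.updateValues = new double[weightCount];
        Arrays.fill(updateValues, initialUpdate);
        this.lastGradient = new double[weightCount];
        this.maxStep = maxStep;
    }

    /** One iRPROP- step, in place; gradients[i] is assumed to be dE/dw_i. */
    void update(double[] weights, double[] gradients) {
        for (int i = 0; i < weights.length; i++) {
            double g = gradients[i];
            double change = g * lastGradient[i];
            if (change > 0) {
                // Same sign as last time: accelerate, but never beyond maxStep.
                updateValues[i] = Math.min(updateValues[i] * ETA_PLUS, maxStep);
            } else if (change < 0) {
                // Sign flip: a minimum was overshot, so shrink the step and skip this update.
                updateValues[i] = Math.max(updateValues[i] * ETA_MINUS, DELTA_MIN);
                g = 0.0;
            }
            weights[i] -= Math.signum(g) * updateValues[i];
            lastGradient[i] = g;
        }
    }

    public static void main(String[] args) {
        // Minimize E(w) = (w - 3)^2, whose gradient is 2(w - 3).
        RpropSketch rprop = new RpropSketch(1, 0.1, 50.0);
        double[] w = {0.0};
        for (int step = 0; step < 200; step++) {
            rprop.update(w, new double[] {2.0 * (w[0] - 3.0)});
        }
        System.out.println("w = " + w[0]);
    }
}
```

Because only the sign of each gradient is used, the step sizes adapt per weight independently of the gradient's scale; maxStep bounds how far a single weight can move in one iteration.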

setIteration

public void setIteration(int iteration)
Set the iteration.

Specified by:
setIteration in interface TrainFlatNetwork
Parameters:
iteration - The iteration.

setNumThreads

public void setNumThreads(int numThreads)
Set the number of threads to use.

Specified by:
setNumThreads in interface TrainFlatNetwork
Parameters:
numThreads - The number of threads to use.


Copyright © 2011. All Rights Reserved.