ml.shifu.guagua.yarn.example.nn.meta
Class NNParams

java.lang.Object
  extended by ml.shifu.guagua.io.HaltBytable
      extended by ml.shifu.guagua.yarn.example.nn.meta.NNParams
All Implemented Interfaces:
ml.shifu.guagua.io.Bytable

public class NNParams
extends ml.shifu.guagua.io.HaltBytable

NNParams holds the neural-network (NN) model state exchanged between master and workers; it can also be stored in ZooKeeper.

weights holds the model weights, which the master sends to the workers.

gradients accumulates the gradients from all workers on the master; the master then uses the accumulated gradients to update the network weights.
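The master-side flow described above can be sketched as follows. This is a minimal, hypothetical illustration, not the guagua implementation: the class name MasterUpdateSketch, the element-wise sum, and the plain gradient-descent step w[i] - learningRate * g[i] are all assumptions, since this page does not say which update rule the example uses.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical sketch of one master-side iteration: sum the gradients reported
 * by all workers, then apply a plain gradient-descent step to the weights.
 * The update rule and learning rate are assumptions, not taken from NNParams.
 */
final class MasterUpdateSketch {

    private MasterUpdateSketch() {
    }

    /** Element-wise sum of the gradient arrays reported by all workers. */
    static double[] sumGradients(List<double[]> workerGradients) {
        if (workerGradients.isEmpty()) {
            throw new IllegalArgumentException("no worker gradients");
        }
        int n = workerGradients.get(0).length;
        double[] sum = new double[n]; // Java zero-initializes new arrays
        for (double[] g : workerGradients) {
            if (g.length != n) {
                throw new IllegalArgumentException("gradient length mismatch");
            }
            for (int i = 0; i < n; i++) {
                sum[i] += g[i];
            }
        }
        return sum;
    }

    /** Returns new weights computed as w[i] - learningRate * g[i]. */
    static double[] applyUpdate(double[] weights, double[] gradients, double learningRate) {
        if (weights.length != gradients.length) {
            throw new IllegalArgumentException("weights/gradients length mismatch");
        }
        double[] updated = Arrays.copyOf(weights, weights.length);
        for (int i = 0; i < updated.length; i++) {
            updated[i] -= learningRate * gradients[i];
        }
        return updated;
    }

    public static void main(String[] args) {
        double[] weights = {0.5, -0.25};
        List<double[]> fromWorkers = Arrays.asList(
                new double[] {0.5, 1.0},   // worker 1
                new double[] {1.5, -1.0}); // worker 2
        double[] total = sumGradients(fromWorkers);
        double[] next = applyUpdate(weights, total, 0.1);
        // The master would now send `next` back to the workers as the new weights.
        System.out.println(Arrays.toString(next));
    }
}
```

Returning a new weight array rather than updating in place keeps the previous iteration's weights available, for example to compare errors across iterations; that choice is also part of the sketch.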


Constructor Summary
NNParams()
           
 
Method Summary
 void accumulateGradients(double[] gradients)
           
 void accumulateTrainSize(long size)
           
 void doReadFields(DataInput in)
           
 void doWrite(DataOutput out)
           
 double[] getGradients()
           
 double getTestError()
           
 double getTrainError()
           
 long getTrainSize()
           
 double[] getWeights()
           
 void reset()
           
 void setGradients(double[] gradients)
           
 void setTestError(double testError)
           
 void setTrainError(double trainError)
           
 void setTrainSize(long trainSize)
           
 void setWeights(double[] weights)
           
 String toString()
           
 
Methods inherited from class ml.shifu.guagua.io.HaltBytable
isHalt, readFields, setHalt, write
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
 

Constructor Detail

NNParams

public NNParams()

Method Detail

getWeights

public double[] getWeights()

setWeights

public void setWeights(double[] weights)

getTestError

public double getTestError()

setTestError

public void setTestError(double testError)

getTrainError

public double getTrainError()

setTrainError

public void setTrainError(double trainError)

accumulateGradients

public void accumulateGradients(double[] gradients)
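This page does not document accumulateGradients. Going by the class description, it presumably adds one worker's gradients into the running total on the master. A hypothetical sketch of that element-wise accumulation follows; the class name GradientAccumulatorSketch, lazily sizing the buffer on the first call, and rejecting arrays of a different length are assumptions.

```java
import java.util.Arrays;

/**
 * Hypothetical stand-in for the gradient-accumulation part of NNParams.
 * Lazy allocation and the length check are assumptions of this sketch.
 */
final class GradientAccumulatorSketch {

    private double[] gradients; // null until the first accumulation

    /** Adds the incoming gradients element-wise into the running total. */
    void accumulateGradients(double[] incoming) {
        if (gradients == null) {
            // First call: size the buffer from the incoming array (zero-filled by Java).
            gradients = new double[incoming.length];
        } else if (gradients.length != incoming.length) {
            throw new IllegalArgumentException(
                    "expected " + gradients.length + " gradients, got " + incoming.length);
        }
        for (int i = 0; i < incoming.length; i++) {
            gradients[i] += incoming[i];
        }
    }

    /** Returns the internal buffer (not a copy), or null before any accumulation. */
    double[] getGradients() {
        return gradients;
    }

    public static void main(String[] args) {
        GradientAccumulatorSketch acc = new GradientAccumulatorSketch();
        acc.accumulateGradients(new double[] {1.0, 2.0, 3.0}); // worker 1
        acc.accumulateGradients(new double[] {0.5, 0.5, 0.5}); // worker 2
        System.out.println(Arrays.toString(acc.getGradients())); // [1.5, 2.5, 3.5]
    }
}
```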

getGradients

public double[] getGradients()
Returns:
the gradients

setGradients

public void setGradients(double[] gradients)
Parameters:
gradients - the gradients to set

getTrainSize

public long getTrainSize()

setTrainSize

public void setTrainSize(long trainSize)

accumulateTrainSize

public void accumulateTrainSize(long size)

reset

public void reset()
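accumulateTrainSize and reset are undocumented here. One plausible reading is that the master sums the number of training records reported by each worker, and reset() runs between iterations so that counts and gradients do not carry over. The sketch below assumes exactly that; the class name IterationStateSketch is hypothetical, and which fields the real reset() clears is not stated on this page.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of per-iteration bookkeeping on the master.
 * Zeroing both the train size and the gradients in reset() is an assumption.
 */
final class IterationStateSketch {

    private long trainSize;           // records reported in the current iteration
    private final double[] gradients; // running gradient sum for the current iteration

    IterationStateSketch(int numWeights) {
        this.gradients = new double[numWeights];
    }

    /** Adds the record count reported by one worker. */
    void accumulateTrainSize(long size) {
        trainSize += size;
    }

    /** Adds one worker's gradients element-wise. */
    void accumulateGradients(double[] incoming) {
        if (incoming.length != gradients.length) {
            throw new IllegalArgumentException("gradient length mismatch");
        }
        for (int i = 0; i < gradients.length; i++) {
            gradients[i] += incoming[i];
        }
    }

    /** Clears per-iteration accumulators before the next round of worker results. */
    void reset() {
        trainSize = 0L;
        Arrays.fill(gradients, 0.0);
    }

    long getTrainSize() {
        return trainSize;
    }

    double[] getGradients() {
        return gradients;
    }

    public static void main(String[] args) {
        IterationStateSketch state = new IterationStateSketch(2);
        state.accumulateTrainSize(1000L); // worker 1
        state.accumulateTrainSize(250L);  // worker 2
        state.accumulateGradients(new double[] {0.25, -0.5});
        System.out.println(state.getTrainSize()); // 1250
        state.reset();
        System.out.println(state.getTrainSize() + " " + Arrays.toString(state.getGradients())); // 0 [0.0, 0.0]
    }
}
```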

doWrite

public void doWrite(DataOutput out)
             throws IOException
Specified by:
doWrite in class ml.shifu.guagua.io.HaltBytable
Throws:
IOException

doReadFields

public void doReadFields(DataInput in)
                  throws IOException
Specified by:
doReadFields in class ml.shifu.guagua.io.HaltBytable
Throws:
IOException
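doWrite and doReadFields are the hooks that HaltBytable subclasses implement; the inherited write and readFields presumably handle the halt flag and delegate to them. The sketch below illustrates that split with java.io.DataOutput/DataInput and a byte-array round trip, the form in which the data could be stored in ZooKeeper. ParamsWireSketch, the boolean halt prefix, the field order, and the length-prefixed array encoding are all assumptions, not the library's actual wire format.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/**
 * Hypothetical sketch of the HaltBytable-style template: write()/readFields()
 * handle the halt flag, doWrite()/doReadFields() handle the subclass payload.
 * Field order and array encoding are assumptions of this sketch.
 */
final class ParamsWireSketch {

    boolean halt; // assumed to be base-class state
    double trainError;
    double testError;
    long trainSize;
    double[] weights = new double[0];
    double[] gradients = new double[0];

    /** Template method: base-class state first, then the subclass payload. */
    void write(DataOutput out) throws IOException {
        out.writeBoolean(halt);
        doWrite(out);
    }

    void readFields(DataInput in) throws IOException {
        halt = in.readBoolean();
        doReadFields(in);
    }

    void doWrite(DataOutput out) throws IOException {
        out.writeDouble(trainError);
        out.writeDouble(testError);
        out.writeLong(trainSize);
        writeArray(out, weights);
        writeArray(out, gradients);
    }

    /** Must read fields in exactly the order doWrite() wrote them. */
    void doReadFields(DataInput in) throws IOException {
        trainError = in.readDouble();
        testError = in.readDouble();
        trainSize = in.readLong();
        weights = readArray(in);
        gradients = readArray(in);
    }

    private static void writeArray(DataOutput out, double[] values) throws IOException {
        out.writeInt(values.length); // length prefix
        for (double v : values) {
            out.writeDouble(v);
        }
    }

    private static double[] readArray(DataInput in) throws IOException {
        double[] values = new double[in.readInt()];
        for (int i = 0; i < values.length; i++) {
            values[i] = in.readDouble();
        }
        return values;
    }

    /** Serializes to bytes, e.g. as a ZooKeeper znode payload. */
    static byte[] toBytes(ParamsWireSketch p) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            p.write(out);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    static ParamsWireSketch fromBytes(byte[] data) {
        ParamsWireSketch p = new ParamsWireSketch();
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(data))) {
            p.readFields(in);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return p;
    }
}
```

Whatever the real field order is, doReadFields must consume fields in exactly the order doWrite produced them; a mismatch usually yields wrong values rather than an immediate error.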

toString

public String toString()
Overrides:
toString in class Object


Copyright © 2014. All Rights Reserved.