Class RankNet
java.lang.Object
    ciir.umass.edu.learning.Ranker
        ciir.umass.edu.learning.neuralnet.RankNet

Direct Known Subclasses:
LambdaRank, ListNet
public class RankNet extends Ranker
Author:
vdang

This class implements RankNet, as described in C.J.C. Burges, T. Shaked, E. Renshaw, A. Lazier, M. Deeds, N. Hamilton and G. Hullender. Learning to rank using gradient descent. In Proc. of ICML, pages 89-96, 2005.
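For orientation, the pairwise model from the cited paper can be summarized as below. The notation (f for the learned scoring function, o_ij for the score difference, \bar{P}_{ij} for the target probability that document i should rank above document j) follows Burges et al. rather than this class; how RankNet maps these quantities onto its fields and methods is not stated on this page.

```latex
\begin{aligned}
o_{ij} &= f(x_i) - f(x_j), \qquad P_{ij} = \frac{e^{o_{ij}}}{1 + e^{o_{ij}}},\\
C_{ij} &= -\bar{P}_{ij}\log P_{ij} - \left(1 - \bar{P}_{ij}\right)\log\left(1 - P_{ij}\right)
        = -\bar{P}_{ij}\,o_{ij} + \log\left(1 + e^{o_{ij}}\right),\\
\left.\frac{\partial C_{ij}}{\partial o_{ij}}\right|_{\bar{P}_{ij} = 1} &= -\frac{1}{1 + e^{o_{ij}}}.
\end{aligned}
```

The last line shows why training on a pair with \bar{P}_{ij} = 1 raises f(x_i) and lowers f(x_j): the gradient is always negative and shrinks toward zero as o_ij grows.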
-
Field Summary
Modifier and Type                                            Field
protected java.util.List<java.util.List<java.lang.Double>>   bestModelOnValidation
protected double                                             error
protected Layer                                              inputLayer
protected double                                             lastError
protected java.util.List<Layer>                              layers
static double                                                learningRate
protected int                                                misorderedPairs
static int                                                   nHiddenLayer
static int                                                   nHiddenNodePerLayer
static int                                                   nIteration
protected Layer                                              outputLayer
protected int                                                straightLoss
protected int                                                totalPairs
Fields inherited from class ciir.umass.edu.learning.Ranker
bestScoreOnValidationData, features, samples, scoreOnTrainingData, scorer, validationSamples, verbose
-
Constructor Summary
Constructor
RankNet()
RankNet(java.util.List<RankList> samples, int[] features, MetricScorer scorer)
-
Method Summary
Modifier and Type    Method and Description
protected void       addHiddenLayer(int size)
protected void       addInput(DataPoint p)
                     Auxiliary functions for pair-wise preference network learning.
protected void       batchBackPropagate(int[][] pairMap, float[][] pairWeight)
protected int[][]    batchFeedForward(RankList rl)
protected void       clearNeuronOutputs()
protected float[][]  computePairWeight(int[][] pairMap, RankList rl)
protected void       connect(int sourceLayer, int sourceNeuron, int targetLayer, int targetNeuron)
Ranker               createNew()
protected double     crossEntropy(double o1, double o2, double targetValue)
protected void       estimateLoss()
double               eval(DataPoint p)
void                 init()
                     Main public functions
protected RankList   internalReorder(RankList rl)
void                 learn()
void                 loadFromString(java.lang.String fullText)
java.lang.String     model()
java.lang.String     name()
protected void       printNetworkConfig()
                     For debugging purposes only.
void                 printParameters()
protected void       printWeightVector()
protected void       propagate(int i)
protected void       restoreBestModelOnValidation()
protected void       saveBestModelOnValidation()
                     Model validation
protected void       setInputOutput(int nInput, int nOutput)
                     Setting up the Neural Network
protected void       setInputOutput(int nInput, int nOutput, int nType)
java.lang.String     toString()
protected void       wire()
Methods inherited from class ciir.umass.edu.learning.Ranker
copy, getFeatures, getScoreOnTrainingData, getScoreOnValidationData, PRINT, PRINT, PRINT_MEMORY_USAGE, PRINTLN, PRINTLN, PRINTTIME, rank, rank, save, setFeatures, setMetricScorer, setTrainingSet, setValidationSet
-
Field Detail
-
nIteration
public static int nIteration
-
nHiddenLayer
public static int nHiddenLayer
-
nHiddenNodePerLayer
public static int nHiddenNodePerLayer
-
learningRate
public static double learningRate
-
layers
protected java.util.List<Layer> layers
-
inputLayer
protected Layer inputLayer
-
outputLayer
protected Layer outputLayer
-
bestModelOnValidation
protected java.util.List<java.util.List<java.lang.Double>> bestModelOnValidation
-
totalPairs
protected int totalPairs
-
misorderedPairs
protected int misorderedPairs
-
error
protected double error
-
lastError
protected double lastError
-
straightLoss
protected int straightLoss
-
Constructor Detail
-
RankNet
public RankNet()
-
RankNet
public RankNet(java.util.List<RankList> samples, int[] features, MetricScorer scorer)
-
Method Detail
-
setInputOutput
protected void setInputOutput(int nInput, int nOutput)
Setting up the Neural Network
-
setInputOutput
protected void setInputOutput(int nInput, int nOutput, int nType)
-
addHiddenLayer
protected void addHiddenLayer(int size)
-
wire
protected void wire()
-
connect
protected void connect(int sourceLayer, int sourceNeuron, int targetLayer, int targetNeuron)
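wire and connect are listed without descriptions. The standalone sketch below shows, under stated assumptions, what fully connecting consecutive layers involves: every neuron of layer l gets one link to every neuron of layer l + 1, addressed by the same four coordinates that connect takes. The layer layout (an input layer, nHiddenLayer hidden layers of nHiddenNodePerLayer neurons each, then a single output neuron producing the score), the omission of bias units, and the names layerSizes, wire and Link are illustrative assumptions, not RankLib's code.

```java
import java.util.ArrayList;
import java.util.List;

/** Standalone sketch of fully connected wiring; NOT RankLib's wire() implementation. */
final class WiringSketch {
    /** One directed connection, addressed like connect(sourceLayer, sourceNeuron, targetLayer, targetNeuron). */
    static final class Link {
        final int sourceLayer;
        final int sourceNeuron;
        final int targetLayer;
        final int targetNeuron;

        Link(int sourceLayer, int sourceNeuron, int targetLayer, int targetNeuron) {
            this.sourceLayer = sourceLayer;
            this.sourceNeuron = sourceNeuron;
            this.targetLayer = targetLayer;
            this.targetNeuron = targetNeuron;
        }
    }

    /** Hypothetical helper: input layer, nHiddenLayer hidden layers, one output neuron (assumed). */
    static int[] layerSizes(int nInput, int nHiddenLayer, int nHiddenNodePerLayer) {
        int[] sizes = new int[nHiddenLayer + 2];
        sizes[0] = nInput;
        for (int h = 1; h <= nHiddenLayer; h++) {
            sizes[h] = nHiddenNodePerLayer;
        }
        sizes[sizes.length - 1] = 1;
        return sizes;
    }

    /** Connects every neuron of layer l to every neuron of layer l + 1 (bias units omitted). */
    static List<Link> wire(int[] layerSizes) {
        List<Link> links = new ArrayList<>();
        for (int l = 0; l + 1 < layerSizes.length; l++) {
            for (int s = 0; s < layerSizes[l]; s++) {
                for (int t = 0; t < layerSizes[l + 1]; t++) {
                    links.add(new Link(l, s, l + 1, t));
                }
            }
        }
        return links;
    }
}
```

With 3 input features and one hidden layer of 4 neurons, this yields 3 × 4 + 4 × 1 = 16 links.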
-
addInput
protected void addInput(DataPoint p)
Auxiliary functions for pair-wise preference network learning.
-
propagate
protected void propagate(int i)
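addInput and propagate carry no description beyond the section comment. As a hedged illustration of a layer-by-layer forward pass, the sketch below feeds one feature vector through a small fully connected network. The weight layout (weights[l][k][m] links unit m of layer l to unit k of layer l + 1, with a bias in the last column), the logistic hidden activation, and the linear output neuron are assumptions for the example; this page does not say which transfer functions RankNet uses.

```java
/** Standalone sketch of a feed-forward pass; NOT RankLib's propagate() implementation. */
final class ForwardPassSketch {
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * weights[l][k] holds the incoming weights of unit k in layer l + 1; its last entry
     * is the bias. Hidden layers apply a sigmoid; the final layer is left linear (assumed).
     */
    static double[] forward(double[][][] weights, double[] input) {
        double[] act = input;
        for (int l = 0; l < weights.length; l++) {
            double[][] w = weights[l];
            double[] next = new double[w.length];
            boolean isOutputLayer = (l == weights.length - 1);
            for (int k = 0; k < w.length; k++) {
                double sum = w[k][act.length]; // bias term stored after the regular weights
                for (int m = 0; m < act.length; m++) {
                    sum += w[k][m] * act[m];
                }
                next[k] = isOutputLayer ? sum : sigmoid(sum);
            }
            act = next;
        }
        return act;
    }
}
```

For example, with input {1.0, 2.0}, one hidden unit with weights {0.5, -0.25, 0.0} and an output unit with weights {2.0, 1.0}, the hidden activation is sigmoid(0) = 0.5 and the score is 2.0 × 0.5 + 1.0 = 2.0.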
-
batchFeedForward
protected int[][] batchFeedForward(RankList rl)
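The int[][] returned by batchFeedForward is later passed to batchBackPropagate and computePairWeight as pairMap. Its layout is not documented on this page; a plausible reading, assumed here, is that row i lists the indices j of documents in the same RankList whose relevance label is strictly lower than that of document i. A standalone sketch of building such a map from plain labels (the name buildPairMap is hypothetical):

```java
import java.util.ArrayList;
import java.util.List;

/** Standalone sketch of a pair map; the layout is an assumption, NOT taken from RankLib. */
final class PairMapSketch {
    /** Row i holds every index j whose label is strictly lower than labels[i]; ties form no pair. */
    static int[][] buildPairMap(float[] labels) {
        int[][] pairMap = new int[labels.length][];
        for (int i = 0; i < labels.length; i++) {
            List<Integer> lower = new ArrayList<>();
            for (int j = 0; j < labels.length; j++) {
                if (labels[i] > labels[j]) {
                    lower.add(j);
                }
            }
            pairMap[i] = lower.stream().mapToInt(Integer::intValue).toArray();
        }
        return pairMap;
    }
}
```

For labels {2, 0, 1}, the rows are {1, 2}, {} and {1}: three ordered pairs in total.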
-
batchBackPropagate
protected void batchBackPropagate(int[][] pairMap, float[][] pairWeight)
-
clearNeuronOutputs
protected void clearNeuronOutputs()
-
computePairWeight
protected float[][] computePairWeight(int[][] pairMap, RankList rl)
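batchBackPropagate receives the pairMap together with the weights produced by computePairWeight. The sketch below computes what such a pass would need at the output neuron: the gradient of the summed pairwise cross-entropy with respect to each document's score, using dC/do_ij = -1 / (1 + e^{o_ij}) for a target probability of 1 (Burges et al., 2005). It assumes pairWeight[i][k] weights the pair (i, pairMap[i][k]) and treats a null pairWeight as a uniform weight of 1; both are assumptions, as is the expectation that subclasses such as LambdaRank supply non-uniform weights. Backpropagating these score gradients into the network weights is left out.

```java
/** Standalone sketch of pairwise score gradients; NOT RankLib's batchBackPropagate(). */
final class PairGradientSketch {
    /**
     * pairMap[i] lists j with label_i > label_j. pairWeight may be null (every pair weighs 1).
     * Returns dC/ds for each document score s, summed over all pairs.
     */
    static double[] scoreGradients(double[] scores, int[][] pairMap, float[][] pairWeight) {
        double[] grad = new double[scores.length];
        for (int i = 0; i < pairMap.length; i++) {
            for (int k = 0; k < pairMap[i].length; k++) {
                int j = pairMap[i][k];
                double w = (pairWeight == null) ? 1.0 : pairWeight[i][k];
                // dC/do_ij = P_ij - 1 = -1 / (1 + exp(o_ij)) when the target probability is 1.
                double g = -w / (1.0 + Math.exp(scores[i] - scores[j]));
                grad[i] += g; // o_ij = s_i - s_j grows with s_i
                grad[j] -= g; // and shrinks with s_j
            }
        }
        return grad;
    }
}
```

A gradient-descent step would then move each score against its gradient, so the more relevant document's score rises and the other's falls.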
-
saveBestModelOnValidation
protected void saveBestModelOnValidation()
Model validation
-
restoreBestModelOnValidation
protected void restoreBestModelOnValidation()
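saveBestModelOnValidation and restoreBestModelOnValidation, together with the bestModelOnValidation field of type java.util.List<java.util.List<java.lang.Double>>, suggest an early-stopping snapshot of the network weights. The sketch below shows that idea in isolation: a deep copy of the weights is taken on save and written back on restore. Storing one inner list per layer, and the class and method names, are assumptions for illustration.

```java
import java.util.ArrayList;
import java.util.List;

/** Standalone sketch of a weight snapshot for early stopping; NOT RankLib's code. */
final class BestModelSnapshotSketch {
    // One inner list per layer (assumed), mirroring the List<List<Double>> field type.
    private final List<List<Double>> best = new ArrayList<>();

    /** Deep-copies the current weights so later training updates cannot change the snapshot. */
    void save(double[][] weights) {
        best.clear();
        for (double[] layer : weights) {
            List<Double> copy = new ArrayList<>(layer.length);
            for (double w : layer) {
                copy.add(w);
            }
            best.add(copy);
        }
    }

    /** Writes the saved weights back in place; shapes must match those passed to save. */
    void restore(double[][] weights) {
        for (int l = 0; l < weights.length; l++) {
            List<Double> saved = best.get(l);
            for (int k = 0; k < weights[l].length; k++) {
                weights[l][k] = saved.get(k);
            }
        }
    }
}
```

In a training loop, one would presumably call save whenever the validation score improves and restore once after the final iteration.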
-
crossEntropy
protected double crossEntropy(double o1, double o2, double targetValue)
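A minimal sketch of a pairwise cross-entropy with the same parameter list as crossEntropy(o1, o2, targetValue), following the closed form from Burges et al. (2005): C = -targetValue · o12 + log(1 + e^{o12}) with o12 = o1 - o2. The natural logarithm and the numerically stable softplus are choices made here; the library may use a different log base or rounding.

```java
/** Standalone sketch of the RankNet pairwise loss; NOT RankLib's crossEntropy(). */
final class CrossEntropySketch {
    /** targetValue is the desired probability that item 1 ranks above item 2. */
    static double crossEntropy(double o1, double o2, double targetValue) {
        double o12 = o1 - o2;
        // log(1 + e^x) computed without overflow: max(x, 0) + log1p(e^{-|x|}).
        double softplus = Math.max(o12, 0.0) + Math.log1p(Math.exp(-Math.abs(o12)));
        return -targetValue * o12 + softplus;
    }
}
```

Equal scores cost log 2 for any target; a pair ordered correctly by a wide margin costs almost nothing, while a badly misordered pair costs roughly the margin.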
-
estimateLoss
protected void estimateLoss()
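estimateLoss is undocumented, but the totalPairs, misorderedPairs and error fields hint at what a pairwise loss estimate tracks. The sketch below computes those three quantities for a single ranked list given model scores and labels: the number of ordered pairs, how many the model gets backwards, and the mean pairwise cross-entropy with target 1. Averaging over pairs, the natural log, and skipping tied labels are assumptions; the sketch does not claim to reproduce RankLib's figures.

```java
/** Standalone sketch of a pairwise loss estimate for one ranked list; NOT RankLib's estimateLoss(). */
final class PairwiseLossSketch {
    /** Holder whose field names mirror totalPairs, misorderedPairs and error. */
    static final class Loss {
        final int totalPairs;
        final int misorderedPairs;
        final double error;

        Loss(int totalPairs, int misorderedPairs, double error) {
            this.totalPairs = totalPairs;
            this.misorderedPairs = misorderedPairs;
            this.error = error;
        }
    }

    static Loss estimate(double[] scores, float[] labels) {
        int total = 0;
        int misordered = 0;
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            for (int j = 0; j < labels.length; j++) {
                if (labels[i] > labels[j]) { // document i should rank above document j
                    total++;
                    double o = scores[i] - scores[j];
                    if (o < 0) {
                        misordered++; // the model places j above i
                    }
                    // Cross-entropy with target 1 is log(1 + e^{-o}), computed stably.
                    sum += Math.max(-o, 0.0) + Math.log1p(Math.exp(-Math.abs(o)));
                }
            }
        }
        return new Loss(total, misordered, total == 0 ? 0.0 : sum / total);
    }
}
```

With scores {1.0, 2.0, 0.0} and labels {2, 1, 0}, there are 3 ordered pairs and only the pair (0, 1) is misordered.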
-
loadFromString
public void loadFromString(java.lang.String fullText)
Specified by:
loadFromString in class Ranker
-
printParameters
public void printParameters()
Specified by:
printParameters in class Ranker
-
printNetworkConfig
protected void printNetworkConfig()
For debugging purposes only.
-
printWeightVector
protected void printWeightVector()
-