001/** 
002 * Copyright (c) 2009, Regents of the University of Colorado 
003 * All rights reserved.
004 * 
005 * Redistribution and use in source and binary forms, with or without
006 * modification, are permitted provided that the following conditions are met:
007 * 
008 * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
009 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
010 * Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
011 * 
012 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
013 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
014 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
015 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
016 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
017 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
018 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
019 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
020 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
021 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
022 * POSSIBILITY OF SUCH DAMAGE. 
023 */
024package org.cleartk.ml.sigmoid;
025
026import java.util.logging.Logger;
027
028/**
029 * <br>
030 * Copyright (c) 2009, Regents of the University of Colorado <br>
031 * All rights reserved.
032 * <p>
033 * This class implements an algorithm to fit a sigmoid function to the output of an SVM classifier.
034 * The algorithm is the one introduced by Hsuan-Tien Lin, Chih-Jen Lin, and Ruby C. Weng (who were
035 * in turn extending work by J. Platt), and this implementation is a direct translation of their
036 * pseudo-code as presented in
037 * 
038 * Lin, Lin, Weng. A note on Platt's probabilistic outputs for support vector machines. In Machine
039 * Learning, vol. 68, pp. 267-276, 2007.
040 * 
041 * @author Philipp G. Wetzler
042 */
043public class LinWengPlatt {
044
045  public static Sigmoid fit(double[] decisionValues, boolean[] labels) throws ConvergenceFailure {
046
047    assert (decisionValues.length == labels.length);
048
049    int nPlus = 0;
050    for (boolean l : labels)
051      if (l)
052        nPlus += 1;
053    int nMinus = labels.length - nPlus;
054
055    int maxIterations = 100;
056    double minimumStepsize = 1e-10;
057    double sigma = 1e-12;
058
059    double hiTarget = (nPlus + 1.0) / (nPlus + 2.0);
060    double loTarget = 1 / (nMinus + 2.0);
061    int n = nMinus + nPlus;
062
063    double t[] = new double[n];
064    for (int i = 0; i < n; i++) {
065      if (labels[i])
066        t[i] = hiTarget;
067      else
068        t[i] = loTarget;
069    }
070
071    double a = 0.0;
072    double b = Math.log((nMinus + 1.0) / (nPlus + 1.0));
073    double f = 0.0;
074
075    for (int i = 0; i < n; i++) {
076      double fApB = decisionValues[i] * a + b;
077      if (fApB >= 0)
078        f += t[i] * fApB + Math.log(1 + Math.exp(-fApB));
079      else
080        f += (t[i] - 1) * fApB + Math.log(1 + Math.exp(fApB));
081    }
082
083    int iterations;
084    for (iterations = 0; iterations < maxIterations; iterations++) {
085      double h11 = sigma;
086      double h22 = sigma;
087      double h21 = 0.0;
088      double g1 = 0.0;
089      double g2 = 0.0;
090
091      for (int i = 0; i < n; i++) {
092        double fApB = decisionValues[i] * a + b;
093        double p, q;
094        if (fApB >= 0) {
095          p = Math.exp(-fApB) / (1.0 + Math.exp(-fApB));
096          q = 1.0 / (1.0 + Math.exp(-fApB));
097        } else {
098          p = 1.0 / (1.0 + Math.exp(fApB));
099          q = Math.exp(fApB) / (1.0 + Math.exp(fApB));
100        }
101        double d2 = p * q;
102        h11 += decisionValues[i] * decisionValues[i] * d2;
103        h22 += d2;
104        h21 += decisionValues[i] * d2;
105        double d1 = t[i] - p;
106        g1 += decisionValues[i] * d1;
107        g2 += d1;
108      }
109
110      if (Math.abs(g1) < 1e-5 && Math.abs(g2) < 1e-5)
111        break;
112
113      double det = h11 * h22 - h21 * h21;
114      double dA = -(h22 * g1 - h21 * g2) / det;
115      double dB = -(-h21 * g1 + h11 * g2) / det;
116      double gd = g1 * dA + g2 * dB;
117      double stepsize = 1;
118
119      while (stepsize >= minimumStepsize) {
120        double newA = a + stepsize * dA;
121        double newB = b + stepsize * dB;
122        double newf = 0.0;
123
124        for (int i = 1; i < n; i++) {
125          double fApB = decisionValues[i] * newA + newB;
126          if (fApB >= 0)
127            newf += t[i] * fApB + Math.log(1 + Math.exp(-fApB));
128          else
129            newf += (t[i] - 1) * fApB + Math.log(1 + Math.exp(fApB));
130        }
131
132        if (newf < f + 0.0001 * stepsize * gd) {
133          a = newA;
134          b = newB;
135          f = newf;
136          break;
137        } else {
138          stepsize /= 2.0;
139        }
140      }
141
142      if (stepsize < minimumStepsize) {
143        Logger logger = Logger.getLogger(LinWengPlatt.class.getName());
144        logger.fine("line search failure");
145        break;
146      }
147    }
148
149    if (iterations >= maxIterations)
150      throw new ConvergenceFailure("Reaching maximum iterations");
151
152    return new Sigmoid(a, b);
153  }
154
155  public static class ConvergenceFailure extends Exception {
156
157    private static final long serialVersionUID = -7570320408478887106L;
158
159    public ConvergenceFailure(String message) {
160      super(message);
161    }
162
163  }
164
165}