001/** 002 * Copyright (c) 2009, Regents of the University of Colorado 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without 006 * modification, are permitted provided that the following conditions are met: 007 * 008 * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 009 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 010 * Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 011 * 012 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 013 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 014 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 015 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 016 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 017 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 018 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 019 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 020 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 021 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 022 * POSSIBILITY OF SUCH DAMAGE. 023 */ 024package org.cleartk.ml.sigmoid; 025 026import java.util.logging.Logger; 027 028/** 029 * <br> 030 * Copyright (c) 2009, Regents of the University of Colorado <br> 031 * All rights reserved. 032 * <p> 033 * This class implements an algorithm to fit a sigmoid function to the output of an SVM classifier. 034 * The algorithm is the one introduced by Hsuan-Tien Lin, Chih-Jen Lin, and Ruby C. Weng (who were 035 * in turn extending work by J. Platt), and this implementation is a direct translation of their 036 * pseudo-code as presented in 037 * 038 * Lin, Lin, Weng. A note on Platt's probabilistic outputs for support vector machines. In Machine 039 * Learning, vol. 68, pp. 267-276, 2007. 040 * 041 * @author Philipp G. Wetzler 042 */ 043public class LinWengPlatt { 044 045 public static Sigmoid fit(double[] decisionValues, boolean[] labels) throws ConvergenceFailure { 046 047 assert (decisionValues.length == labels.length); 048 049 int nPlus = 0; 050 for (boolean l : labels) 051 if (l) 052 nPlus += 1; 053 int nMinus = labels.length - nPlus; 054 055 int maxIterations = 100; 056 double minimumStepsize = 1e-10; 057 double sigma = 1e-12; 058 059 double hiTarget = (nPlus + 1.0) / (nPlus + 2.0); 060 double loTarget = 1 / (nMinus + 2.0); 061 int n = nMinus + nPlus; 062 063 double t[] = new double[n]; 064 for (int i = 0; i < n; i++) { 065 if (labels[i]) 066 t[i] = hiTarget; 067 else 068 t[i] = loTarget; 069 } 070 071 double a = 0.0; 072 double b = Math.log((nMinus + 1.0) / (nPlus + 1.0)); 073 double f = 0.0; 074 075 for (int i = 0; i < n; i++) { 076 double fApB = decisionValues[i] * a + b; 077 if (fApB >= 0) 078 f += t[i] * fApB + Math.log(1 + Math.exp(-fApB)); 079 else 080 f += (t[i] - 1) * fApB + Math.log(1 + Math.exp(fApB)); 081 } 082 083 int iterations; 084 for (iterations = 0; iterations < maxIterations; iterations++) { 085 double h11 = sigma; 086 double h22 = sigma; 087 double h21 = 0.0; 088 double g1 = 0.0; 089 double g2 = 0.0; 090 091 for (int i = 0; i < n; i++) { 092 double fApB = decisionValues[i] * a + b; 093 double p, q; 094 if (fApB >= 0) { 095 p = Math.exp(-fApB) / (1.0 + Math.exp(-fApB)); 096 q = 1.0 / (1.0 + Math.exp(-fApB)); 097 } else { 098 p = 1.0 / (1.0 + Math.exp(fApB)); 099 q = Math.exp(fApB) / (1.0 + Math.exp(fApB)); 100 } 101 double d2 = p * q; 102 h11 += decisionValues[i] * decisionValues[i] * d2; 103 h22 += d2; 104 h21 += decisionValues[i] * d2; 105 double d1 = t[i] - p; 106 g1 += decisionValues[i] * d1; 107 g2 += d1; 108 } 109 110 if (Math.abs(g1) < 1e-5 && Math.abs(g2) < 1e-5) 111 break; 112 113 double det = h11 * h22 - h21 * h21; 114 double dA = -(h22 * g1 - h21 * g2) / det; 115 double dB = -(-h21 * g1 + h11 * g2) / det; 116 double gd = g1 * dA + g2 * dB; 117 double stepsize = 1; 118 119 while (stepsize >= minimumStepsize) { 120 double newA = a + stepsize * dA; 121 double newB = b + stepsize * dB; 122 double newf = 0.0; 123 124 for (int i = 1; i < n; i++) { 125 double fApB = decisionValues[i] * newA + newB; 126 if (fApB >= 0) 127 newf += t[i] * fApB + Math.log(1 + Math.exp(-fApB)); 128 else 129 newf += (t[i] - 1) * fApB + Math.log(1 + Math.exp(fApB)); 130 } 131 132 if (newf < f + 0.0001 * stepsize * gd) { 133 a = newA; 134 b = newB; 135 f = newf; 136 break; 137 } else { 138 stepsize /= 2.0; 139 } 140 } 141 142 if (stepsize < minimumStepsize) { 143 Logger logger = Logger.getLogger(LinWengPlatt.class.getName()); 144 logger.fine("line search failure"); 145 break; 146 } 147 } 148 149 if (iterations >= maxIterations) 150 throw new ConvergenceFailure("Reaching maximum iterations"); 151 152 return new Sigmoid(a, b); 153 } 154 155 public static class ConvergenceFailure extends Exception { 156 157 private static final long serialVersionUID = -7570320408478887106L; 158 159 public ConvergenceFailure(String message) { 160 super(message); 161 } 162 163 } 164 165}