001/** 002 * Copyright (c) 2011, Regents of the University of Colorado 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without 006 * modification, are permitted provided that the following conditions are met: 007 * 008 * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 009 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 010 * Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 011 * 012 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 013 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 014 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 015 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 016 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 017 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 018 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 019 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 020 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 021 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 022 * POSSIBILITY OF SUCH DAMAGE. 023 */ 024package org.cleartk.eval.util; 025 026import java.util.HashMap; 027import java.util.Map; 028import java.util.SortedSet; 029import java.util.TreeSet; 030 031import com.google.common.base.Objects; 032import com.google.common.collect.HashMultiset; 033import com.google.common.collect.Multiset; 034import com.google.common.collect.Ordering; 035 036/** 037 * This data structure provides an easy way to build and output a confusion matrix. A confusion 038 * matrix is a two dimensional table with a row and table for each class. Each element in the matrix 039 * shows the number of test examples for which the actual class is the row and the predicted class 040 * is the column. Display of this matrix is useful for identifying when a system is confusing two 041 * classes 042 * 043 * <br> 044 * For more info @see <a href="http://en.wikipedia.org/wiki/Confusion_matrix">The wikipedia page on 045 * Confusion Matrices</a> 046 * 047 * <br> 048 * Copyright (c) 2011, Regents of the University of Colorado <br> 049 * All rights reserved. 050 * 051 * @author Lee Becker 052 * 053 * @param <T> 054 * The data type used to represent the class labels 055 */ 056 057public class ConfusionMatrix<T extends Comparable<? super T>> { 058 private Map<T, Multiset<T>> matrix; 059 060 private SortedSet<T> classes; 061 062 /** 063 * Creates an empty confusion Matrix 064 */ 065 public ConfusionMatrix() { 066 this.matrix = new HashMap<T, Multiset<T>>(); 067 this.classes = new TreeSet<T>(Ordering.natural().nullsFirst()); 068 } 069 070 /** 071 * Creates a new ConfusionMatrix initialized with the contents of another ConfusionMatrix. 072 */ 073 public ConfusionMatrix(ConfusionMatrix<T> other) { 074 this(); 075 this.add(other); 076 } 077 078 /** 079 * Increments the entry specified by actual and predicted by one. 080 */ 081 public void add(T actual, T predicted) { 082 add(actual, predicted, 1); 083 } 084 085 /** 086 * Increments the entry specified by actual and predicted by count. 087 */ 088 public void add(T actual, T predicted, int count) { 089 if (matrix.containsKey(actual)) { 090 matrix.get(actual).add(predicted, count); 091 } else { 092 Multiset<T> counts = HashMultiset.create(); 093 counts.add(predicted, count); 094 matrix.put(actual, counts); 095 } 096 097 classes.add(actual); 098 classes.add(predicted); 099 } 100 101 /** 102 * Adds the entries from another confusion matrix to this one. 103 */ 104 public void add(ConfusionMatrix<T> other) { 105 for (T actual : other.matrix.keySet()) { 106 Multiset<T> counts = other.matrix.get(actual); 107 for (T predicted : counts.elementSet()) { 108 int count = counts.count(predicted); 109 this.add(actual, predicted, count); 110 } 111 } 112 } 113 114 /** 115 * Gives the set of all classes in the confusion matrix. 116 */ 117 public SortedSet<T> getClasses() { 118 return classes; 119 } 120 121 /** 122 * Gives the count of the number of times the "predicted" class was predicted for the "actual" 123 * class. 124 */ 125 public int getCount(T actual, T predicted) { 126 if (!matrix.containsKey(actual)) { 127 return 0; 128 } else { 129 return matrix.get(actual).count(predicted); 130 } 131 } 132 133 /** 134 * Computes the total number of times the class was predicted by the classifier. 135 */ 136 public int getPredictedTotal(T predicted) { 137 int total = 0; 138 for (T actual : classes) { 139 total += getCount(actual, predicted); 140 } 141 return total; 142 } 143 144 /** 145 * Computes the total number of times the class actually appeared in the data. 146 */ 147 public int getActualTotal(T actual) { 148 if (!matrix.containsKey(actual)) { 149 return 0; 150 } else { 151 int total = 0; 152 for (T elem : matrix.get(actual).elementSet()) { 153 total += matrix.get(actual).count(elem); 154 } 155 return total; 156 } 157 } 158 159 @Override 160 public String toString() { 161 return Objects.toStringHelper(this).add("matrix", this.matrix).toString(); 162 } 163 164 /** 165 * Outputs the ConfusionMatrix as comma-separated values for easy import into spreadsheets 166 */ 167 public String toCSV() { 168 StringBuilder builder = new StringBuilder(); 169 170 // Header Row 171 builder.append(",,Predicted Class,\n"); 172 173 // Predicted Classes Header Row 174 builder.append(",,"); 175 for (T predicted : classes) { 176 builder.append(String.format("%s,", predicted)); 177 } 178 builder.append("Total\n"); 179 180 // Data Rows 181 String firstColumnLabel = "Actual Class,"; 182 for (T actual : classes) { 183 builder.append(firstColumnLabel); 184 firstColumnLabel = ","; 185 builder.append(String.format("%s,", actual)); 186 187 for (T predicted : classes) { 188 builder.append(getCount(actual, predicted)); 189 builder.append(","); 190 } 191 // Actual Class Totals Column 192 builder.append(getActualTotal(actual)); 193 builder.append("\n"); 194 } 195 196 // Predicted Class Totals Row 197 builder.append(",Total,"); 198 for (T predicted : classes) { 199 builder.append(getPredictedTotal(predicted)); 200 builder.append(","); 201 } 202 builder.append("\n"); 203 204 return builder.toString(); 205 } 206 207 /** 208 * Outputs Confusion Matrix in an HTML table. Cascading Style Sheets (CSS) can control the table's 209 * appearance by defining the empty-space, actual-count-header, predicted-class-header, and 210 * count-element classes. For example 211 * 212 * @return html string 213 */ 214 public String toHTML() { 215 StringBuilder builder = new StringBuilder(); 216 217 int numClasses = classes.size(); 218 // Header Row 219 builder.append("<table>\n"); 220 builder.append("<tr><th class=\"empty-space\" colspan=\"2\" rowspan=\"2\">"); 221 builder.append(String.format( 222 "<th class=\"predicted-class-header\" colspan=\"%d\">Predicted Class</th></tr>\n", 223 numClasses + 1)); 224 225 // Predicted Classes Header Row 226 builder.append("<tr>"); 227 // builder.append("<th></th><th></th>"); 228 for (T predicted : classes) { 229 builder.append("<th class=\"predicted-class-header\">"); 230 builder.append(predicted); 231 builder.append("</th>"); 232 } 233 builder.append("<th class=\"predicted-class-header\">Total</th>"); 234 builder.append("</tr>\n"); 235 236 // Data Rows 237 String firstColumnLabel = String.format( 238 "<tr><th class=\"actual-class-header\" rowspan=\"%d\">Actual Class</th>", 239 numClasses + 1); 240 for (T actual : classes) { 241 builder.append(firstColumnLabel); 242 firstColumnLabel = "<tr>"; 243 builder.append(String.format("<th class=\"actual-class-header\" >%s</th>", actual)); 244 245 for (T predicted : classes) { 246 builder.append("<td class=\"count-element\">"); 247 builder.append(getCount(actual, predicted)); 248 builder.append("</td>"); 249 } 250 251 // Actual Class Totals Column 252 builder.append("<td class=\"count-element\">"); 253 builder.append(getActualTotal(actual)); 254 builder.append("</td>"); 255 builder.append("</tr>\n"); 256 } 257 258 // Predicted Class Totals Row 259 builder.append("<tr><th class=\"actual-class-header\">Total</th>"); 260 for (T predicted : classes) { 261 builder.append("<td class=\"count-element\">"); 262 builder.append(getPredictedTotal(predicted)); 263 builder.append("</td>"); 264 } 265 builder.append("<td class=\"empty-space\"></td>\n"); 266 builder.append("</tr>\n"); 267 builder.append("</table>\n"); 268 269 return builder.toString(); 270 } 271 272 public static void main(String[] args) { 273 ConfusionMatrix<String> confusionMatrix = new ConfusionMatrix<String>(); 274 275 confusionMatrix.add("a", "a", 88); 276 confusionMatrix.add("a", "b", 10); 277 // confusionMatrix.add("a", "c", 2); 278 confusionMatrix.add("b", "a", 14); 279 confusionMatrix.add("b", "b", 40); 280 confusionMatrix.add("b", "c", 6); 281 confusionMatrix.add("c", "a", 18); 282 confusionMatrix.add("c", "b", 10); 283 confusionMatrix.add("c", "c", 12); 284 285 ConfusionMatrix<String> confusionMatrix2 = new ConfusionMatrix<String>(confusionMatrix); 286 confusionMatrix2.add(confusionMatrix); 287 System.out.println(confusionMatrix2.toHTML()); 288 System.out.println(confusionMatrix2.toCSV()); 289 } 290}