001/** 
002 * Copyright (c) 2011, Regents of the University of Colorado 
003 * All rights reserved.
004 * 
005 * Redistribution and use in source and binary forms, with or without
006 * modification, are permitted provided that the following conditions are met:
007 * 
008 * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
009 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
010 * Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
011 * 
012 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
013 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
014 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
015 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
016 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
017 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
018 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
019 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
020 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
021 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
022 * POSSIBILITY OF SUCH DAMAGE. 
023 */
024package org.cleartk.eval.util;
025
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;

import com.google.common.base.Objects;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;
035
/**
 * This data structure provides an easy way to build and output a confusion matrix. A confusion
 * matrix is a two-dimensional table with a row and a column for each class. Each element in the
 * matrix shows the number of test examples for which the actual class is the row and the predicted
 * class is the column. Displaying this matrix is useful for identifying when a system is confusing
 * two classes.
 * 
 * <br>
 * For more information, see <a href="http://en.wikipedia.org/wiki/Confusion_matrix">the Wikipedia
 * page on confusion matrices</a>.
 * 
 * <br>
 * Copyright (c) 2011, Regents of the University of Colorado <br>
 * All rights reserved.
 * 
 * @author Lee Becker
 * 
 * @param <T>
 *          The data type used to represent the class labels
 */
056
057public class ConfusionMatrix<T extends Comparable<? super T>> {
058  private Map<T, Multiset<T>> matrix;
059
060  private SortedSet<T> classes;
061
062  /**
063   * Creates an empty confusion Matrix
064   */
065  public ConfusionMatrix() {
066    this.matrix = new HashMap<T, Multiset<T>>();
067    this.classes = new TreeSet<T>(Ordering.natural().nullsFirst());
068  }
069
070  /**
071   * Creates a new ConfusionMatrix initialized with the contents of another ConfusionMatrix.
072   */
073  public ConfusionMatrix(ConfusionMatrix<T> other) {
074    this();
075    this.add(other);
076  }
077
078  /**
079   * Increments the entry specified by actual and predicted by one.
080   */
081  public void add(T actual, T predicted) {
082    add(actual, predicted, 1);
083  }
084
085  /**
086   * Increments the entry specified by actual and predicted by count.
087   */
088  public void add(T actual, T predicted, int count) {
089    if (matrix.containsKey(actual)) {
090      matrix.get(actual).add(predicted, count);
091    } else {
092      Multiset<T> counts = HashMultiset.create();
093      counts.add(predicted, count);
094      matrix.put(actual, counts);
095    }
096
097    classes.add(actual);
098    classes.add(predicted);
099  }
100
101  /**
102   * Adds the entries from another confusion matrix to this one.
103   */
104  public void add(ConfusionMatrix<T> other) {
105    for (T actual : other.matrix.keySet()) {
106      Multiset<T> counts = other.matrix.get(actual);
107      for (T predicted : counts.elementSet()) {
108        int count = counts.count(predicted);
109        this.add(actual, predicted, count);
110      }
111    }
112  }
113
114  /**
115   * Gives the set of all classes in the confusion matrix.
116   */
117  public SortedSet<T> getClasses() {
118    return classes;
119  }
120
121  /**
122   * Gives the count of the number of times the "predicted" class was predicted for the "actual"
123   * class.
124   */
125  public int getCount(T actual, T predicted) {
126    if (!matrix.containsKey(actual)) {
127      return 0;
128    } else {
129      return matrix.get(actual).count(predicted);
130    }
131  }
132
133  /**
134   * Computes the total number of times the class was predicted by the classifier.
135   */
136  public int getPredictedTotal(T predicted) {
137    int total = 0;
138    for (T actual : classes) {
139      total += getCount(actual, predicted);
140    }
141    return total;
142  }
143
144  /**
145   * Computes the total number of times the class actually appeared in the data.
146   */
147  public int getActualTotal(T actual) {
148    if (!matrix.containsKey(actual)) {
149      return 0;
150    } else {
151      int total = 0;
152      for (T elem : matrix.get(actual).elementSet()) {
153        total += matrix.get(actual).count(elem);
154      }
155      return total;
156    }
157  }
158
159  @Override
160  public String toString() {
161    return Objects.toStringHelper(this).add("matrix", this.matrix).toString();
162  }
163
164  /**
165   * Outputs the ConfusionMatrix as comma-separated values for easy import into spreadsheets
166   */
167  public String toCSV() {
168    StringBuilder builder = new StringBuilder();
169
170    // Header Row
171    builder.append(",,Predicted Class,\n");
172
173    // Predicted Classes Header Row
174    builder.append(",,");
175    for (T predicted : classes) {
176      builder.append(String.format("%s,", predicted));
177    }
178    builder.append("Total\n");
179
180    // Data Rows
181    String firstColumnLabel = "Actual Class,";
182    for (T actual : classes) {
183      builder.append(firstColumnLabel);
184      firstColumnLabel = ",";
185      builder.append(String.format("%s,", actual));
186
187      for (T predicted : classes) {
188        builder.append(getCount(actual, predicted));
189        builder.append(",");
190      }
191      // Actual Class Totals Column
192      builder.append(getActualTotal(actual));
193      builder.append("\n");
194    }
195
196    // Predicted Class Totals Row
197    builder.append(",Total,");
198    for (T predicted : classes) {
199      builder.append(getPredictedTotal(predicted));
200      builder.append(",");
201    }
202    builder.append("\n");
203
204    return builder.toString();
205  }
206
207  /**
208   * Outputs Confusion Matrix in an HTML table. Cascading Style Sheets (CSS) can control the table's
209   * appearance by defining the empty-space, actual-count-header, predicted-class-header, and
210   * count-element classes. For example
211   * 
212   * @return html string
213   */
214  public String toHTML() {
215    StringBuilder builder = new StringBuilder();
216
217    int numClasses = classes.size();
218    // Header Row
219    builder.append("<table>\n");
220    builder.append("<tr><th class=\"empty-space\" colspan=\"2\" rowspan=\"2\">");
221    builder.append(String.format(
222        "<th class=\"predicted-class-header\" colspan=\"%d\">Predicted Class</th></tr>\n",
223        numClasses + 1));
224
225    // Predicted Classes Header Row
226    builder.append("<tr>");
227    // builder.append("<th></th><th></th>");
228    for (T predicted : classes) {
229      builder.append("<th class=\"predicted-class-header\">");
230      builder.append(predicted);
231      builder.append("</th>");
232    }
233    builder.append("<th class=\"predicted-class-header\">Total</th>");
234    builder.append("</tr>\n");
235
236    // Data Rows
237    String firstColumnLabel = String.format(
238        "<tr><th class=\"actual-class-header\" rowspan=\"%d\">Actual Class</th>",
239        numClasses + 1);
240    for (T actual : classes) {
241      builder.append(firstColumnLabel);
242      firstColumnLabel = "<tr>";
243      builder.append(String.format("<th class=\"actual-class-header\" >%s</th>", actual));
244
245      for (T predicted : classes) {
246        builder.append("<td class=\"count-element\">");
247        builder.append(getCount(actual, predicted));
248        builder.append("</td>");
249      }
250
251      // Actual Class Totals Column
252      builder.append("<td class=\"count-element\">");
253      builder.append(getActualTotal(actual));
254      builder.append("</td>");
255      builder.append("</tr>\n");
256    }
257
258    // Predicted Class Totals Row
259    builder.append("<tr><th class=\"actual-class-header\">Total</th>");
260    for (T predicted : classes) {
261      builder.append("<td class=\"count-element\">");
262      builder.append(getPredictedTotal(predicted));
263      builder.append("</td>");
264    }
265    builder.append("<td class=\"empty-space\"></td>\n");
266    builder.append("</tr>\n");
267    builder.append("</table>\n");
268
269    return builder.toString();
270  }
271
272  public static void main(String[] args) {
273    ConfusionMatrix<String> confusionMatrix = new ConfusionMatrix<String>();
274
275    confusionMatrix.add("a", "a", 88);
276    confusionMatrix.add("a", "b", 10);
277    // confusionMatrix.add("a", "c", 2);
278    confusionMatrix.add("b", "a", 14);
279    confusionMatrix.add("b", "b", 40);
280    confusionMatrix.add("b", "c", 6);
281    confusionMatrix.add("c", "a", 18);
282    confusionMatrix.add("c", "b", 10);
283    confusionMatrix.add("c", "c", 12);
284
285    ConfusionMatrix<String> confusionMatrix2 = new ConfusionMatrix<String>(confusionMatrix);
286    confusionMatrix2.add(confusionMatrix);
287    System.out.println(confusionMatrix2.toHTML());
288    System.out.println(confusionMatrix2.toCSV());
289  }
290}