/*
 * Copyright (c) 2012, Regents of the University of Colorado
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
 * Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
023 */ 024package org.cleartk.eval; 025 026import java.io.Serializable; 027import java.util.ArrayList; 028import java.util.Collection; 029import java.util.Collections; 030import java.util.HashMap; 031import java.util.HashSet; 032import java.util.List; 033import java.util.Map; 034import java.util.Set; 035 036import org.apache.uima.cas.Feature; 037import org.apache.uima.jcas.cas.TOP; 038import org.apache.uima.jcas.tcas.Annotation; 039import org.cleartk.eval.util.ConfusionMatrix; 040 041import com.google.common.base.Function; 042import com.google.common.base.Objects; 043import com.google.common.base.Objects.ToStringHelper; 044import com.google.common.collect.HashMultiset; 045import com.google.common.collect.Multiset; 046 047/** 048 * Stores statistics for comparing {@link Annotation}s extracted by a system to gold 049 * {@link Annotation}s. 050 * 051 * <br> 052 * Copyright (c) 2012, Regents of the University of Colorado <br> 053 * All rights reserved. 054 * 055 * @author Steven Bethard 056 */ 057public class AnnotationStatistics<OUTCOME_TYPE extends Comparable<? super OUTCOME_TYPE>> implements 058 Serializable { 059 060 private static final long serialVersionUID = 1L; 061 062 private Multiset<OUTCOME_TYPE> referenceOutcomes; 063 064 private Multiset<OUTCOME_TYPE> predictedOutcomes; 065 066 private Multiset<OUTCOME_TYPE> correctOutcomes; 067 068 private ConfusionMatrix<OUTCOME_TYPE> confusionMatrix; 069 070 /** 071 * Creates a {@link Function} that converts an {@link Annotation} into a hashable representation 072 * of its begin and end offsets. 073 * 074 * The {@link Function} created by this method is suitable for passing to the first 075 * {@link Function} argument of {@link #add(Collection, Collection, Function, Function)}. 
076 */ 077 public static <ANNOTATION_TYPE extends Annotation> Function<ANNOTATION_TYPE, Span> annotationToSpan() { 078 return new Function<ANNOTATION_TYPE, Span>() { 079 @Override 080 public Span apply(ANNOTATION_TYPE annotation) { 081 return new Span(annotation); 082 } 083 }; 084 } 085 086 /** 087 * Creates a {@link Function} that extracts a feature value from a {@link TOP}. 088 * 089 * The {@link Function} created by this method is suitable for passing to the second 090 * {@link Function} argument of {@link #add(Collection, Collection, Function, Function)}. 091 * 092 * @param featureName 093 * The name of the feature whose value is to be extracted. 094 */ 095 public static <ANNOTATION_TYPE extends TOP> Function<ANNOTATION_TYPE, String> annotationToFeatureValue( 096 final String featureName) { 097 return new Function<ANNOTATION_TYPE, String>() { 098 @Override 099 public String apply(ANNOTATION_TYPE annotation) { 100 Feature feature = annotation.getType().getFeatureByBaseName(featureName); 101 return annotation.getFeatureValueAsString(feature); 102 } 103 }; 104 } 105 106 /** 107 * Creates a {@link Function} that always returns null. 108 * 109 * This may be useful when only the span of the offset is important, but you still need to pass in 110 * the final argument of {@link #add(Collection, Collection, Function, Function)}. 111 */ 112 public static <ANNOTATION_TYPE, OUTCOME_TYPE> Function<ANNOTATION_TYPE, OUTCOME_TYPE> annotationToNull() { 113 return new Function<ANNOTATION_TYPE, OUTCOME_TYPE>() { 114 @Override 115 public OUTCOME_TYPE apply(ANNOTATION_TYPE annotation) { 116 return null; 117 } 118 }; 119 } 120 121 /** 122 * Add all statistics together. 123 * 124 * This is often useful for combining individual fold statistics that result from methods like 125 * {@link Evaluation_ImplBase#crossValidation(List, int)}. 126 * 127 * @param statistics 128 * The sequence of statistics that should be combined. 129 * @return The combination of all the individual statistics. 
130 */ 131 public static <OUTCOME_TYPE extends Comparable<? super OUTCOME_TYPE>> AnnotationStatistics<OUTCOME_TYPE> addAll( 132 Iterable<AnnotationStatistics<OUTCOME_TYPE>> statistics) { 133 AnnotationStatistics<OUTCOME_TYPE> result = new AnnotationStatistics<OUTCOME_TYPE>(); 134 for (AnnotationStatistics<OUTCOME_TYPE> item : statistics) { 135 result.addAll(item); 136 } 137 return result; 138 } 139 140 /** 141 * Create an AnnotationStatistics that compares {@link Annotation}s based on their begin and end 142 * offsets, plus a {@link Feature} of the {@link Annotation} that represents the outcome or label. 143 */ 144 public AnnotationStatistics() { 145 this.referenceOutcomes = HashMultiset.create(); 146 this.predictedOutcomes = HashMultiset.create(); 147 this.correctOutcomes = HashMultiset.create(); 148 this.confusionMatrix = new ConfusionMatrix<OUTCOME_TYPE>(); 149 } 150 151 /** 152 * Update the statistics, comparing the reference annotations to the predicted annotations. 153 * 154 * Annotations are considered to match if they have the same character offsets in the text. All 155 * outcomes (e.g. as returned in {@link #confusions()}) will be <code>null</code>. 156 * 157 * @param referenceAnnotations 158 * The reference annotations, typically identified by humans. 159 * @param predictedAnnotations 160 * The predicted annotations, typically identified by a model. 161 */ 162 public <ANNOTATION_TYPE extends Annotation> void add( 163 Collection<? extends ANNOTATION_TYPE> referenceAnnotations, 164 Collection<? extends ANNOTATION_TYPE> predictedAnnotations) { 165 this.add( 166 referenceAnnotations, 167 predictedAnnotations, 168 AnnotationStatistics.<ANNOTATION_TYPE> annotationToSpan(), 169 AnnotationStatistics.<ANNOTATION_TYPE, OUTCOME_TYPE> annotationToNull()); 170 } 171 172 /** 173 * Update the statistics, comparing the reference annotations to the predicted annotations. 
174 * 175 * Annotations are considered to match if they have the same span (according to 176 * {@code annotationToSpan}) and if they have the same outcome (according to 177 * {@code annotationToOutcome}). 178 * 179 * @param referenceAnnotations 180 * The reference annotations, typically identified by humans. 181 * @param predictedAnnotations 182 * The predicted annotations, typically identified by a model. 183 * @param annotationToSpan 184 * A function that defines how to convert an annotation into a hashable object that 185 * represents the span of that annotation. The {@link #annotationToSpan()} method 186 * provides an example function that could be used here. 187 * @param annotationToOutcome 188 * A function that defines how to convert an annotation into an object that represents 189 * the outcome (or "label") assigned to that annotation. The 190 * {@link #annotationToFeatureValue(String)} method provides a sample function that could 191 * be used here. 192 */ 193 public <ANNOTATION_TYPE, SPAN_TYPE> void add( 194 Collection<? extends ANNOTATION_TYPE> referenceAnnotations, 195 Collection<? 
extends ANNOTATION_TYPE> predictedAnnotations, 196 Function<ANNOTATION_TYPE, SPAN_TYPE> annotationToSpan, 197 Function<ANNOTATION_TYPE, OUTCOME_TYPE> annotationToOutcome) { 198 199 // map gold spans to their outcomes 200 Map<SPAN_TYPE, OUTCOME_TYPE> referenceSpanOutcomes = new HashMap<SPAN_TYPE, OUTCOME_TYPE>(); 201 for (ANNOTATION_TYPE ann : referenceAnnotations) { 202 referenceSpanOutcomes.put(annotationToSpan.apply(ann), annotationToOutcome.apply(ann)); 203 } 204 205 // map system spans to their outcomes 206 Map<SPAN_TYPE, OUTCOME_TYPE> predictedSpanOutcomes = new HashMap<SPAN_TYPE, OUTCOME_TYPE>(); 207 for (ANNOTATION_TYPE ann : predictedAnnotations) { 208 predictedSpanOutcomes.put(annotationToSpan.apply(ann), annotationToOutcome.apply(ann)); 209 } 210 211 // update the gold and system outcomes 212 this.referenceOutcomes.addAll(referenceSpanOutcomes.values()); 213 this.predictedOutcomes.addAll(predictedSpanOutcomes.values()); 214 215 // determine the outcomes that were correct 216 Set<SPAN_TYPE> intersection = new HashSet<SPAN_TYPE>(); 217 intersection.addAll(referenceSpanOutcomes.keySet()); 218 intersection.retainAll(predictedSpanOutcomes.keySet()); 219 for (SPAN_TYPE span : intersection) { 220 OUTCOME_TYPE goldOutcome = referenceSpanOutcomes.get(span); 221 OUTCOME_TYPE systemOutcome = predictedSpanOutcomes.get(span); 222 if (Objects.equal(goldOutcome, systemOutcome)) { 223 this.correctOutcomes.add(goldOutcome); 224 } 225 } 226 227 // update the confusion matrix 228 Set<SPAN_TYPE> union = new HashSet<SPAN_TYPE>(); 229 union.addAll(referenceSpanOutcomes.keySet()); 230 union.addAll(predictedSpanOutcomes.keySet()); 231 for (SPAN_TYPE span : union) { 232 OUTCOME_TYPE goldOutcome = referenceSpanOutcomes.get(span); 233 OUTCOME_TYPE systemOutcome = predictedSpanOutcomes.get(span); 234 this.confusionMatrix.add(goldOutcome, systemOutcome); 235 } 236 } 237 238 /** 239 * Adds all the statistics collected by another AnnotationStatistics to this one. 
240 * 241 * @param that 242 * The other statistics that should be added to this one. 243 */ 244 public void addAll(AnnotationStatistics<OUTCOME_TYPE> that) { 245 this.referenceOutcomes.addAll(that.referenceOutcomes); 246 this.predictedOutcomes.addAll(that.predictedOutcomes); 247 this.correctOutcomes.addAll(that.correctOutcomes); 248 this.confusionMatrix.add(that.confusionMatrix); 249 } 250 251 public int countCorrectOutcomes() { 252 return this.correctOutcomes.size(); 253 } 254 255 public int countCorrectOutcomes(OUTCOME_TYPE outcome) { 256 return this.correctOutcomes.count(outcome); 257 } 258 259 public int countPredictedOutcomes() { 260 return this.predictedOutcomes.size(); 261 } 262 263 public int countPredictedOutcomes(OUTCOME_TYPE outcome) { 264 return this.predictedOutcomes.count(outcome); 265 } 266 267 public int countReferenceOutcomes() { 268 return this.referenceOutcomes.size(); 269 } 270 271 public int countReferenceOutcomes(OUTCOME_TYPE outcome) { 272 return this.referenceOutcomes.count(outcome); 273 } 274 275 public int countFalseNegatives(@SuppressWarnings("unchecked") OUTCOME_TYPE... positiveOutcomes) { 276 int numReferenceOutcomes = this.countReferenceOutcomes(); 277 int numPredictedOutcomes = this.countPredictedOutcomes(); 278 if (numReferenceOutcomes != numPredictedOutcomes) { 279 throw new IllegalStateException( 280 String.format( 281 "Expected number equal number of references outcomes and predicted outcomes. Had reference outcomes=%d, predicted outcomes=%d", 282 numReferenceOutcomes, 283 numPredictedOutcomes, 284 this.countPredictedOutcomes())); 285 } 286 int totalFalseNegatives = 0; 287 for (OUTCOME_TYPE positiveOutcome : positiveOutcomes) { 288 totalFalseNegatives += this.countReferenceOutcomes(positiveOutcome) 289 - this.countCorrectOutcomes(positiveOutcome); 290 } 291 return totalFalseNegatives; 292 } 293 294 public int countFalsePositives(@SuppressWarnings("unchecked") OUTCOME_TYPE... 
positiveOutcomes) { 295 int numReferenceOutcomes = this.countReferenceOutcomes(); 296 int numPredictedOutcomes = this.countPredictedOutcomes(); 297 if (numReferenceOutcomes != numPredictedOutcomes) { 298 throw new IllegalStateException( 299 String.format( 300 "Expected number equal number of references outcomes and predicted outcomes. Had reference outcomes=%d, predicted outcomes=%d", 301 numReferenceOutcomes, 302 numPredictedOutcomes, 303 this.countPredictedOutcomes())); 304 } 305 int totalFalsePositives = 0; 306 for (OUTCOME_TYPE positiveOutcome : positiveOutcomes) { 307 totalFalsePositives += this.countPredictedOutcomes(positiveOutcome) 308 - this.countCorrectOutcomes(positiveOutcome); 309 } 310 311 return totalFalsePositives; 312 } 313 314 public int countTrueNegatives(@SuppressWarnings("unchecked") OUTCOME_TYPE... positiveOutcomes) { 315 int numReferenceOutcomes = this.countReferenceOutcomes(); 316 int numPredictedOutcomes = this.countPredictedOutcomes(); 317 if (numReferenceOutcomes != numPredictedOutcomes) { 318 throw new IllegalStateException( 319 String.format( 320 "Expected number equal number of references outcomes and predicted outcomes. Had reference outcomes=%d, predicted outcomes=%d", 321 numReferenceOutcomes, 322 numPredictedOutcomes, 323 this.countPredictedOutcomes())); 324 } 325 int totalTrueNegatives = this.countCorrectOutcomes(); 326 327 for (OUTCOME_TYPE positiveOutcome : positiveOutcomes) { 328 totalTrueNegatives -= this.countCorrectOutcomes(positiveOutcome); 329 } 330 331 return totalTrueNegatives; 332 333 } 334 335 public int countTruePositives(@SuppressWarnings("unchecked") OUTCOME_TYPE... positiveOutcomes) { 336 int numReferenceOutcomes = this.countReferenceOutcomes(); 337 int numPredictedOutcomes = this.countPredictedOutcomes(); 338 if (numReferenceOutcomes != numPredictedOutcomes) { 339 throw new IllegalStateException( 340 String.format( 341 "Expected number equal number of references outcomes and predicted outcomes. 
Had reference outcomes=%d, predicted outcomes=%d", 342 numReferenceOutcomes, 343 numPredictedOutcomes, 344 this.countPredictedOutcomes())); 345 } 346 347 int totalTruePositives = 0; 348 for (OUTCOME_TYPE positiveOutcome : positiveOutcomes) { 349 totalTruePositives += this.countCorrectOutcomes(positiveOutcome); 350 } 351 return totalTruePositives; 352 } 353 354 /** 355 * Returns the {@link ConfusionMatrix} tabulating reference outcomes matched to predicted 356 * outcomes. 357 * 358 * @return The confusion matrix. 359 */ 360 public ConfusionMatrix<OUTCOME_TYPE> confusions() { 361 return this.confusionMatrix; 362 } 363 364 public double precision() { 365 int nSystem = this.countPredictedOutcomes(); 366 return nSystem == 0 ? 1.0 : ((double) this.countCorrectOutcomes()) / nSystem; 367 } 368 369 public double precision(OUTCOME_TYPE outcome) { 370 int nSystem = this.countPredictedOutcomes(outcome); 371 return nSystem == 0 ? 1.0 : ((double) this.countCorrectOutcomes(outcome)) / nSystem; 372 } 373 374 public double recall() { 375 int nGold = this.countReferenceOutcomes(); 376 return nGold == 0 ? 1.0 : ((double) this.countCorrectOutcomes()) / nGold; 377 } 378 379 public double recall(OUTCOME_TYPE outcome) { 380 int nGold = this.countReferenceOutcomes(outcome); 381 return nGold == 0 ? 1.0 : ((double) this.countCorrectOutcomes(outcome)) / nGold; 382 } 383 384 public double f(double beta) { 385 double p = this.precision(); 386 double r = this.recall(); 387 double num = (1 + beta * beta) * p * r; 388 double den = (beta * beta * p) + r; 389 return den == 0.0 ? 0.0 : num / den; 390 } 391 392 public double f(double beta, OUTCOME_TYPE outcome) { 393 double p = this.precision(outcome); 394 double r = this.recall(outcome); 395 double num = (1 + beta * beta) * p * r; 396 double den = (beta * beta * p) + r; 397 return den == 0.0 ? 
0.0 : num / den; 398 } 399 400 public double f1() { 401 return this.f(1.0); 402 } 403 404 public double f1(OUTCOME_TYPE outcome) { 405 return f(1.0, outcome); 406 } 407 408 @Override 409 public String toString() { 410 StringBuilder result = new StringBuilder(); 411 result.append("P\tR\tF1\t#gold\t#system\t#correct\n"); 412 result.append(String.format( 413 "%.3f\t%.3f\t%.3f\t%d\t%d\t%d\tOVERALL\n", 414 this.precision(), 415 this.recall(), 416 this.f1(), 417 this.referenceOutcomes.size(), 418 this.predictedOutcomes.size(), 419 this.correctOutcomes.size())); 420 List<OUTCOME_TYPE> outcomes = new ArrayList<OUTCOME_TYPE>(this.referenceOutcomes.elementSet()); 421 if (outcomes.size() > 1) { 422 Collections.sort(outcomes); 423 for (OUTCOME_TYPE outcome : outcomes) { 424 result.append(String.format( 425 "%.3f\t%.3f\t%.3f\t%d\t%d\t%d\t%s\n", 426 this.precision(outcome), 427 this.recall(outcome), 428 this.f1(outcome), 429 this.referenceOutcomes.count(outcome), 430 this.predictedOutcomes.count(outcome), 431 this.correctOutcomes.count(outcome), 432 outcome)); 433 } 434 } 435 return result.toString(); 436 } 437 438 private static class Span { 439 440 public int end; 441 442 public int begin; 443 444 public Span(Annotation annotation) { 445 this.begin = annotation.getBegin(); 446 this.end = annotation.getEnd(); 447 } 448 449 @Override 450 public int hashCode() { 451 return Objects.hashCode(this.begin, this.end); 452 } 453 454 @Override 455 public boolean equals(Object obj) { 456 if (!this.getClass().equals(obj.getClass())) { 457 return false; 458 } 459 Span that = (Span) obj; 460 return this.begin == that.begin && this.end == that.end; 461 } 462 463 @Override 464 public String toString() { 465 ToStringHelper helper = Objects.toStringHelper(this); 466 helper.add("begin", this.begin); 467 helper.add("end", this.end); 468 return helper.toString(); 469 } 470 } 471}