001/* 
002 * Copyright (c) 2012, Regents of the University of Colorado 
003 * All rights reserved.
004 * 
005 * Redistribution and use in source and binary forms, with or without
006 * modification, are permitted provided that the following conditions are met:
007 * 
008 * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
009 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
010 * Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
011 * 
012 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
013 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
014 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
015 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
016 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
017 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
018 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
019 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
020 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
021 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
022 * POSSIBILITY OF SUCH DAMAGE. 
023 */
024package org.cleartk.eval;
025
026import java.io.Serializable;
027import java.util.ArrayList;
028import java.util.Collection;
029import java.util.Collections;
030import java.util.HashMap;
031import java.util.HashSet;
032import java.util.List;
033import java.util.Map;
034import java.util.Set;
035
036import org.apache.uima.cas.Feature;
037import org.apache.uima.jcas.cas.TOP;
038import org.apache.uima.jcas.tcas.Annotation;
039import org.cleartk.eval.util.ConfusionMatrix;
040
041import com.google.common.base.Function;
042import com.google.common.base.Objects;
043import com.google.common.base.Objects.ToStringHelper;
044import com.google.common.collect.HashMultiset;
045import com.google.common.collect.Multiset;
046
047/**
048 * Stores statistics for comparing {@link Annotation}s extracted by a system to gold
049 * {@link Annotation}s.
050 * 
051 * <br>
052 * Copyright (c) 2012, Regents of the University of Colorado <br>
053 * All rights reserved.
054 * 
055 * @author Steven Bethard
056 */
057public class AnnotationStatistics<OUTCOME_TYPE extends Comparable<? super OUTCOME_TYPE>> implements
058    Serializable {
059
060  private static final long serialVersionUID = 1L;
061
062  private Multiset<OUTCOME_TYPE> referenceOutcomes;
063
064  private Multiset<OUTCOME_TYPE> predictedOutcomes;
065
066  private Multiset<OUTCOME_TYPE> correctOutcomes;
067
068  private ConfusionMatrix<OUTCOME_TYPE> confusionMatrix;
069
070  /**
071   * Creates a {@link Function} that converts an {@link Annotation} into a hashable representation
072   * of its begin and end offsets.
073   * 
074   * The {@link Function} created by this method is suitable for passing to the first
075   * {@link Function} argument of {@link #add(Collection, Collection, Function, Function)}.
076   */
077  public static <ANNOTATION_TYPE extends Annotation> Function<ANNOTATION_TYPE, Span> annotationToSpan() {
078    return new Function<ANNOTATION_TYPE, Span>() {
079      @Override
080      public Span apply(ANNOTATION_TYPE annotation) {
081        return new Span(annotation);
082      }
083    };
084  }
085
086  /**
087   * Creates a {@link Function} that extracts a feature value from a {@link TOP}.
088   * 
089   * The {@link Function} created by this method is suitable for passing to the second
090   * {@link Function} argument of {@link #add(Collection, Collection, Function, Function)}.
091   * 
092   * @param featureName
093   *          The name of the feature whose value is to be extracted.
094   */
095  public static <ANNOTATION_TYPE extends TOP> Function<ANNOTATION_TYPE, String> annotationToFeatureValue(
096      final String featureName) {
097    return new Function<ANNOTATION_TYPE, String>() {
098      @Override
099      public String apply(ANNOTATION_TYPE annotation) {
100        Feature feature = annotation.getType().getFeatureByBaseName(featureName);
101        return annotation.getFeatureValueAsString(feature);
102      }
103    };
104  }
105
106  /**
107   * Creates a {@link Function} that always returns null.
108   * 
109   * This may be useful when only the span of the offset is important, but you still need to pass in
110   * the final argument of {@link #add(Collection, Collection, Function, Function)}.
111   */
112  public static <ANNOTATION_TYPE, OUTCOME_TYPE> Function<ANNOTATION_TYPE, OUTCOME_TYPE> annotationToNull() {
113    return new Function<ANNOTATION_TYPE, OUTCOME_TYPE>() {
114      @Override
115      public OUTCOME_TYPE apply(ANNOTATION_TYPE annotation) {
116        return null;
117      }
118    };
119  }
120
121  /**
122   * Add all statistics together.
123   * 
124   * This is often useful for combining individual fold statistics that result from methods like
125   * {@link Evaluation_ImplBase#crossValidation(List, int)}.
126   * 
127   * @param statistics
128   *          The sequence of statistics that should be combined.
129   * @return The combination of all the individual statistics.
130   */
131  public static <OUTCOME_TYPE extends Comparable<? super OUTCOME_TYPE>> AnnotationStatistics<OUTCOME_TYPE> addAll(
132      Iterable<AnnotationStatistics<OUTCOME_TYPE>> statistics) {
133    AnnotationStatistics<OUTCOME_TYPE> result = new AnnotationStatistics<OUTCOME_TYPE>();
134    for (AnnotationStatistics<OUTCOME_TYPE> item : statistics) {
135      result.addAll(item);
136    }
137    return result;
138  }
139
140  /**
141   * Create an AnnotationStatistics that compares {@link Annotation}s based on their begin and end
142   * offsets, plus a {@link Feature} of the {@link Annotation} that represents the outcome or label.
143   */
144  public AnnotationStatistics() {
145    this.referenceOutcomes = HashMultiset.create();
146    this.predictedOutcomes = HashMultiset.create();
147    this.correctOutcomes = HashMultiset.create();
148    this.confusionMatrix = new ConfusionMatrix<OUTCOME_TYPE>();
149  }
150
151  /**
152   * Update the statistics, comparing the reference annotations to the predicted annotations.
153   * 
154   * Annotations are considered to match if they have the same character offsets in the text. All
155   * outcomes (e.g. as returned in {@link #confusions()}) will be <code>null</code>.
156   * 
157   * @param referenceAnnotations
158   *          The reference annotations, typically identified by humans.
159   * @param predictedAnnotations
160   *          The predicted annotations, typically identified by a model.
161   */
162  public <ANNOTATION_TYPE extends Annotation> void add(
163      Collection<? extends ANNOTATION_TYPE> referenceAnnotations,
164      Collection<? extends ANNOTATION_TYPE> predictedAnnotations) {
165    this.add(
166        referenceAnnotations,
167        predictedAnnotations,
168        AnnotationStatistics.<ANNOTATION_TYPE> annotationToSpan(),
169        AnnotationStatistics.<ANNOTATION_TYPE, OUTCOME_TYPE> annotationToNull());
170  }
171
172  /**
173   * Update the statistics, comparing the reference annotations to the predicted annotations.
174   * 
175   * Annotations are considered to match if they have the same span (according to
176   * {@code annotationToSpan}) and if they have the same outcome (according to
177   * {@code annotationToOutcome}).
178   * 
179   * @param referenceAnnotations
180   *          The reference annotations, typically identified by humans.
181   * @param predictedAnnotations
182   *          The predicted annotations, typically identified by a model.
183   * @param annotationToSpan
184   *          A function that defines how to convert an annotation into a hashable object that
185   *          represents the span of that annotation. The {@link #annotationToSpan()} method
186   *          provides an example function that could be used here.
187   * @param annotationToOutcome
188   *          A function that defines how to convert an annotation into an object that represents
189   *          the outcome (or "label") assigned to that annotation. The
190   *          {@link #annotationToFeatureValue(String)} method provides a sample function that could
191   *          be used here.
192   */
193  public <ANNOTATION_TYPE, SPAN_TYPE> void add(
194      Collection<? extends ANNOTATION_TYPE> referenceAnnotations,
195      Collection<? extends ANNOTATION_TYPE> predictedAnnotations,
196      Function<ANNOTATION_TYPE, SPAN_TYPE> annotationToSpan,
197      Function<ANNOTATION_TYPE, OUTCOME_TYPE> annotationToOutcome) {
198
199    // map gold spans to their outcomes
200    Map<SPAN_TYPE, OUTCOME_TYPE> referenceSpanOutcomes = new HashMap<SPAN_TYPE, OUTCOME_TYPE>();
201    for (ANNOTATION_TYPE ann : referenceAnnotations) {
202      referenceSpanOutcomes.put(annotationToSpan.apply(ann), annotationToOutcome.apply(ann));
203    }
204
205    // map system spans to their outcomes
206    Map<SPAN_TYPE, OUTCOME_TYPE> predictedSpanOutcomes = new HashMap<SPAN_TYPE, OUTCOME_TYPE>();
207    for (ANNOTATION_TYPE ann : predictedAnnotations) {
208      predictedSpanOutcomes.put(annotationToSpan.apply(ann), annotationToOutcome.apply(ann));
209    }
210
211    // update the gold and system outcomes
212    this.referenceOutcomes.addAll(referenceSpanOutcomes.values());
213    this.predictedOutcomes.addAll(predictedSpanOutcomes.values());
214
215    // determine the outcomes that were correct
216    Set<SPAN_TYPE> intersection = new HashSet<SPAN_TYPE>();
217    intersection.addAll(referenceSpanOutcomes.keySet());
218    intersection.retainAll(predictedSpanOutcomes.keySet());
219    for (SPAN_TYPE span : intersection) {
220      OUTCOME_TYPE goldOutcome = referenceSpanOutcomes.get(span);
221      OUTCOME_TYPE systemOutcome = predictedSpanOutcomes.get(span);
222      if (Objects.equal(goldOutcome, systemOutcome)) {
223        this.correctOutcomes.add(goldOutcome);
224      }
225    }
226
227    // update the confusion matrix
228    Set<SPAN_TYPE> union = new HashSet<SPAN_TYPE>();
229    union.addAll(referenceSpanOutcomes.keySet());
230    union.addAll(predictedSpanOutcomes.keySet());
231    for (SPAN_TYPE span : union) {
232      OUTCOME_TYPE goldOutcome = referenceSpanOutcomes.get(span);
233      OUTCOME_TYPE systemOutcome = predictedSpanOutcomes.get(span);
234      this.confusionMatrix.add(goldOutcome, systemOutcome);
235    }
236  }
237
238  /**
239   * Adds all the statistics collected by another AnnotationStatistics to this one.
240   * 
241   * @param that
242   *          The other statistics that should be added to this one.
243   */
244  public void addAll(AnnotationStatistics<OUTCOME_TYPE> that) {
245    this.referenceOutcomes.addAll(that.referenceOutcomes);
246    this.predictedOutcomes.addAll(that.predictedOutcomes);
247    this.correctOutcomes.addAll(that.correctOutcomes);
248    this.confusionMatrix.add(that.confusionMatrix);
249  }
250
251  public int countCorrectOutcomes() {
252    return this.correctOutcomes.size();
253  }
254
255  public int countCorrectOutcomes(OUTCOME_TYPE outcome) {
256    return this.correctOutcomes.count(outcome);
257  }
258
259  public int countPredictedOutcomes() {
260    return this.predictedOutcomes.size();
261  }
262
263  public int countPredictedOutcomes(OUTCOME_TYPE outcome) {
264    return this.predictedOutcomes.count(outcome);
265  }
266
267  public int countReferenceOutcomes() {
268    return this.referenceOutcomes.size();
269  }
270
271  public int countReferenceOutcomes(OUTCOME_TYPE outcome) {
272    return this.referenceOutcomes.count(outcome);
273  }
274
275  public int countFalseNegatives(@SuppressWarnings("unchecked") OUTCOME_TYPE... positiveOutcomes) {
276    int numReferenceOutcomes = this.countReferenceOutcomes();
277    int numPredictedOutcomes = this.countPredictedOutcomes();
278    if (numReferenceOutcomes != numPredictedOutcomes) {
279      throw new IllegalStateException(
280          String.format(
281              "Expected number equal number of references outcomes and predicted outcomes.  Had reference outcomes=%d, predicted outcomes=%d",
282              numReferenceOutcomes,
283              numPredictedOutcomes,
284              this.countPredictedOutcomes()));
285    }
286    int totalFalseNegatives = 0;
287    for (OUTCOME_TYPE positiveOutcome : positiveOutcomes) {
288      totalFalseNegatives += this.countReferenceOutcomes(positiveOutcome)
289          - this.countCorrectOutcomes(positiveOutcome);
290    }
291    return totalFalseNegatives;
292  }
293
294  public int countFalsePositives(@SuppressWarnings("unchecked") OUTCOME_TYPE... positiveOutcomes) {
295    int numReferenceOutcomes = this.countReferenceOutcomes();
296    int numPredictedOutcomes = this.countPredictedOutcomes();
297    if (numReferenceOutcomes != numPredictedOutcomes) {
298      throw new IllegalStateException(
299          String.format(
300              "Expected number equal number of references outcomes and predicted outcomes.  Had reference outcomes=%d, predicted outcomes=%d",
301              numReferenceOutcomes,
302              numPredictedOutcomes,
303              this.countPredictedOutcomes()));
304    }
305    int totalFalsePositives = 0;
306    for (OUTCOME_TYPE positiveOutcome : positiveOutcomes) {
307      totalFalsePositives += this.countPredictedOutcomes(positiveOutcome)
308          - this.countCorrectOutcomes(positiveOutcome);
309    }
310
311    return totalFalsePositives;
312  }
313
314  public int countTrueNegatives(@SuppressWarnings("unchecked") OUTCOME_TYPE... positiveOutcomes) {
315    int numReferenceOutcomes = this.countReferenceOutcomes();
316    int numPredictedOutcomes = this.countPredictedOutcomes();
317    if (numReferenceOutcomes != numPredictedOutcomes) {
318      throw new IllegalStateException(
319          String.format(
320              "Expected number equal number of references outcomes and predicted outcomes.  Had reference outcomes=%d, predicted outcomes=%d",
321              numReferenceOutcomes,
322              numPredictedOutcomes,
323              this.countPredictedOutcomes()));
324    }
325    int totalTrueNegatives = this.countCorrectOutcomes();
326
327    for (OUTCOME_TYPE positiveOutcome : positiveOutcomes) {
328      totalTrueNegatives -= this.countCorrectOutcomes(positiveOutcome);
329    }
330
331    return totalTrueNegatives;
332
333  }
334
335  public int countTruePositives(@SuppressWarnings("unchecked") OUTCOME_TYPE... positiveOutcomes) {
336    int numReferenceOutcomes = this.countReferenceOutcomes();
337    int numPredictedOutcomes = this.countPredictedOutcomes();
338    if (numReferenceOutcomes != numPredictedOutcomes) {
339      throw new IllegalStateException(
340          String.format(
341              "Expected number equal number of references outcomes and predicted outcomes.  Had reference outcomes=%d, predicted outcomes=%d",
342              numReferenceOutcomes,
343              numPredictedOutcomes,
344              this.countPredictedOutcomes()));
345    }
346
347    int totalTruePositives = 0;
348    for (OUTCOME_TYPE positiveOutcome : positiveOutcomes) {
349      totalTruePositives += this.countCorrectOutcomes(positiveOutcome);
350    }
351    return totalTruePositives;
352  }
353
354  /**
355   * Returns the {@link ConfusionMatrix} tabulating reference outcomes matched to predicted
356   * outcomes.
357   * 
358   * @return The confusion matrix.
359   */
360  public ConfusionMatrix<OUTCOME_TYPE> confusions() {
361    return this.confusionMatrix;
362  }
363
364  public double precision() {
365    int nSystem = this.countPredictedOutcomes();
366    return nSystem == 0 ? 1.0 : ((double) this.countCorrectOutcomes()) / nSystem;
367  }
368
369  public double precision(OUTCOME_TYPE outcome) {
370    int nSystem = this.countPredictedOutcomes(outcome);
371    return nSystem == 0 ? 1.0 : ((double) this.countCorrectOutcomes(outcome)) / nSystem;
372  }
373
374  public double recall() {
375    int nGold = this.countReferenceOutcomes();
376    return nGold == 0 ? 1.0 : ((double) this.countCorrectOutcomes()) / nGold;
377  }
378
379  public double recall(OUTCOME_TYPE outcome) {
380    int nGold = this.countReferenceOutcomes(outcome);
381    return nGold == 0 ? 1.0 : ((double) this.countCorrectOutcomes(outcome)) / nGold;
382  }
383
384  public double f(double beta) {
385    double p = this.precision();
386    double r = this.recall();
387    double num = (1 + beta * beta) * p * r;
388    double den = (beta * beta * p) + r;
389    return den == 0.0 ? 0.0 : num / den;
390  }
391
392  public double f(double beta, OUTCOME_TYPE outcome) {
393    double p = this.precision(outcome);
394    double r = this.recall(outcome);
395    double num = (1 + beta * beta) * p * r;
396    double den = (beta * beta * p) + r;
397    return den == 0.0 ? 0.0 : num / den;
398  }
399
400  public double f1() {
401    return this.f(1.0);
402  }
403
404  public double f1(OUTCOME_TYPE outcome) {
405    return f(1.0, outcome);
406  }
407
408  @Override
409  public String toString() {
410    StringBuilder result = new StringBuilder();
411    result.append("P\tR\tF1\t#gold\t#system\t#correct\n");
412    result.append(String.format(
413        "%.3f\t%.3f\t%.3f\t%d\t%d\t%d\tOVERALL\n",
414        this.precision(),
415        this.recall(),
416        this.f1(),
417        this.referenceOutcomes.size(),
418        this.predictedOutcomes.size(),
419        this.correctOutcomes.size()));
420    List<OUTCOME_TYPE> outcomes = new ArrayList<OUTCOME_TYPE>(this.referenceOutcomes.elementSet());
421    if (outcomes.size() > 1) {
422      Collections.sort(outcomes);
423      for (OUTCOME_TYPE outcome : outcomes) {
424        result.append(String.format(
425            "%.3f\t%.3f\t%.3f\t%d\t%d\t%d\t%s\n",
426            this.precision(outcome),
427            this.recall(outcome),
428            this.f1(outcome),
429            this.referenceOutcomes.count(outcome),
430            this.predictedOutcomes.count(outcome),
431            this.correctOutcomes.count(outcome),
432            outcome));
433      }
434    }
435    return result.toString();
436  }
437
438  private static class Span {
439
440    public int end;
441
442    public int begin;
443
444    public Span(Annotation annotation) {
445      this.begin = annotation.getBegin();
446      this.end = annotation.getEnd();
447    }
448
449    @Override
450    public int hashCode() {
451      return Objects.hashCode(this.begin, this.end);
452    }
453
454    @Override
455    public boolean equals(Object obj) {
456      if (!this.getClass().equals(obj.getClass())) {
457        return false;
458      }
459      Span that = (Span) obj;
460      return this.begin == that.begin && this.end == that.end;
461    }
462
463    @Override
464    public String toString() {
465      ToStringHelper helper = Objects.toStringHelper(this);
466      helper.add("begin", this.begin);
467      helper.add("end", this.end);
468      return helper.toString();
469    }
470  }
471}