001/* 002 * Copyright (c) 2012, Regents of the University of Colorado 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without 006 * modification, are permitted provided that the following conditions are met: 007 * 008 * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 009 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 010 * Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 011 * 012 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 013 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 014 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 015 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 016 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 017 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 018 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 019 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 020 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 021 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 022 * POSSIBILITY OF SUCH DAMAGE. 023 */ 024package org.cleartk.eval; 025 026import java.io.File; 027import java.util.ArrayList; 028import java.util.List; 029 030import org.apache.uima.cas.CAS; 031import org.apache.uima.collection.CollectionReader; 032 033/** 034 * A base class for training, testing and cross-validation of models. 035 * 036 * <br> 037 * Copyright (c) 2012, Regents of the University of Colorado <br> 038 * All rights reserved. 039 * 040 * @param <ITEM_TYPE> 041 * The type of items that define the train and test data in 042 * {@link #trainAndTest(List, List)} and {@link #crossValidation(List, int)}, and that are 043 * used to create {@link CollectionReader}s in {@link #getCollectionReader(List)}. A common 044 * choice for this parameter when working with data on a filesystem is {@link File}. 045 * @param <STATS_TYPE> 046 * The type of statistics object that will be returned by testing methods. A common choice 047 * for this parameter is {@link AnnotationStatistics}. 048 * @author Steven Bethard 049 */ 050public abstract class Evaluation_ImplBase<ITEM_TYPE, STATS_TYPE> { 051 protected File baseDirectory; 052 053 /** 054 * Create an evaluation that will write all auxiliary files to the given directory. 055 * 056 * @param baseDirectory 057 * The directory for all evaluation files. 058 */ 059 public Evaluation_ImplBase(File baseDirectory) { 060 this.baseDirectory = baseDirectory; 061 } 062 063 /** 064 * Train a model on one set of items and test it on another. 065 * 066 * @param trainItems 067 * The items on which to train. 068 * @param testItems 069 * The items on which to test. 070 * @return The statistics that result from testing the model. 071 */ 072 public STATS_TYPE trainAndTest(List<ITEM_TYPE> trainItems, List<ITEM_TYPE> testItems) 073 throws Exception { 074 File subDirectory = new File(this.baseDirectory, "train_and_test"); 075 subDirectory.mkdirs(); 076 this.train(this.getCollectionReader(trainItems), subDirectory); 077 return this.test(this.getCollectionReader(testItems), subDirectory); 078 } 079 080 /** 081 * Run a cross-validation. 082 * 083 * Splits the items into nFolds approximately equal subsets, and for each subset, training a 084 * classifier on all the remaining data and testing on the subset. 085 * 086 * @param items 087 * The items on which to train and test. Good machine learning practice requires that 088 * these items come only from the training data, and not from the test data. 089 * @param nFolds 090 * The number of subsets into which the items should be split. Note that the number of 091 * folds may not be larger than the number of items. 092 * @return The statistics that result from testing the model, one for each fold. 093 */ 094 public List<STATS_TYPE> crossValidation(List<ITEM_TYPE> items, int nFolds) throws Exception { 095 if (nFolds > items.size()) { 096 String message = "Cannot have %d folds with only %d items"; 097 throw new IllegalArgumentException(String.format(message, nFolds, items.size())); 098 } 099 100 List<STATS_TYPE> stats = new ArrayList<STATS_TYPE>(); 101 for (int fold = 0; fold < nFolds; ++fold) { 102 File subDirectory = new File(this.baseDirectory, "fold_" + fold); 103 subDirectory.mkdirs(); 104 List<ITEM_TYPE> trainItems = this.selectFoldTrainItems(items, nFolds, fold); 105 List<ITEM_TYPE> testItems = this.selectFoldTestItems(items, nFolds, fold); 106 this.train(this.getCollectionReader(trainItems), subDirectory); 107 stats.add(this.test(this.getCollectionReader(testItems), subDirectory)); 108 } 109 return stats; 110 } 111 112 /** 113 * Determines which items should be used for training in one fold of a cross-validation. 114 * 115 * The default implementation includes all items except for every (nFolds)th item, but this may be 116 * overridden in subclasses. 117 * 118 * @param items 119 * The full list of training items. 120 * @param nFolds 121 * The total number of folds in this cross validation. 122 * @param fold 123 * The index of the fold (0 <= fold < nFolds) whose training items are to be selected. 124 * @return The items that should be used for training. 125 */ 126 protected List<ITEM_TYPE> selectFoldTrainItems(List<ITEM_TYPE> items, int nFolds, int fold) { 127 List<ITEM_TYPE> trainItems = new ArrayList<ITEM_TYPE>(); 128 for (int i = 0; i < items.size(); ++i) { 129 if (i % nFolds != fold) { 130 trainItems.add(items.get(i)); 131 } 132 } 133 return trainItems; 134 } 135 136 /** 137 * Determines which items should be used for testing in one fold of a cross-validation. 138 * 139 * The default implementation includes every (nFolds)th item, but this may be overridden in 140 * subclasses. 141 * 142 * @param items 143 * The full list of training items. 144 * @param nFolds 145 * The total number of folds in this cross validation. 146 * @param fold 147 * The index of the fold (0 <= fold < nFolds) whose test items are to be selected. 148 * @return The items that should be used for testing. 149 */ 150 protected List<ITEM_TYPE> selectFoldTestItems(List<ITEM_TYPE> items, int nFolds, int fold) { 151 List<ITEM_TYPE> testItems = new ArrayList<ITEM_TYPE>(); 152 for (int i = 0; i < items.size(); ++i) { 153 if (i % nFolds == fold) { 154 testItems.add(items.get(i)); 155 } 156 } 157 return testItems; 158 } 159 160 /** 161 * Creates a {@link CollectionReader} from the given items. 162 * 163 * This method is called in {@link #trainAndTest(List, List)} and 164 * {@link #crossValidation(List, int)} to create readers both for the training data and for the 165 * testing data. 166 * 167 * @param items 168 * Items from the training, test or cross-validation sets. 169 * @return A {@link CollectionReader} that produces {@link CAS}es for the items. 170 */ 171 protected abstract CollectionReader getCollectionReader(List<ITEM_TYPE> items) throws Exception; 172 173 /** 174 * Trains a model on a set of training data. 175 * 176 * @param collectionReader 177 * The data on which the model should be trained. 178 * @param directory 179 * The directory in which any model files should be written. 180 */ 181 protected abstract void train(CollectionReader collectionReader, File directory) throws Exception; 182 183 /** 184 * Evaluates a model on a set of testing data. 185 * 186 * @param collectionReader 187 * The data on which the model should be tested. 188 * @param directory 189 * The directory in which any model files should be written. This method may safely 190 * assume that {@link #train(CollectionReader, File)} was called on this same directory 191 * before {@link #test(CollectionReader, File)} was called. 192 * @return The statistics that result from testing the model 193 */ 194 protected abstract STATS_TYPE test(CollectionReader collectionReader, File directory) 195 throws Exception; 196 197}