/**
 * Copyright (c) 2012, Regents of the University of Colorado
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
 * Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
package org.cleartk.ml.feature.selection;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.cleartk.ml.Feature;
import org.cleartk.ml.Instance;
import org.cleartk.ml.feature.extractor.CleartkExtractorException;
import org.cleartk.ml.feature.extractor.FeatureExtractor1;
import org.cleartk.ml.feature.selection.MutualInformationFeatureSelectionExtractor.CombineScoreMethod.CombineScoreFunction;
import org.cleartk.ml.feature.selection.MutualInformationFeatureSelectionExtractor.MutualInformationStats.ComputeFeatureScore;
import org.cleartk.ml.feature.transform.TransformableFeature;

import com.google.common.base.Function;
import com.google.common.base.Joiner;
import com.google.common.collect.Collections2;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;
import com.google.common.collect.Table;

/**
 * <br>
 * Copyright (c) 2012, Regents of the University of Colorado <br>
 * All rights reserved.
 * <p>
 *
 * Selects features via mutual information statistics between the features extracted from its
 * sub-extractor and the outcome values they are paired with in classification instances.
 * <p>
 * Before training, {@link #extract(JCas, Annotation)} wraps the sub-extractor's features in a single
 * {@link TransformableFeature}; after {@link #train(Iterable)} or {@link #load(URI)} it returns only
 * the features accepted by {@link #apply(Feature)}.
 *
 * @author Lee Becker
 *
 */
public class MutualInformationFeatureSelectionExtractor<OUTCOME_T, FOCUS_T extends Annotation>
    extends FeatureSelectionExtractor<OUTCOME_T> implements FeatureExtractor1<FOCUS_T> {

  /**
   * Specifies how scores for each outcome should be combined/aggregated into a single score
   */
  public static enum CombineScoreMethod {
    /** Average mutual information across all classes. */
    AVERAGE,
    /** Take the highest mutual information value over all classes. */
    MAX;
    // MERGE, // Take k-largest mutual information values for each class and merge into a single
    // collection - currently omitted because it requires a different extraction flow

    /**
     * Base type for functions that reduce a per-outcome score map to one score.
     */
    public abstract static class CombineScoreFunction<OUTCOME_T> implements
        Function<Map<OUTCOME_T, Double>, Double> {
    }

    /**
     * Combines per-outcome scores by taking their arithmetic mean.
     */
    public static class AverageScores<OUTCOME_T> extends CombineScoreFunction<OUTCOME_T> {
      /**
       * @param input
       *          score for each outcome
       * @return the sum of the scores divided by their number; an empty map yields NaN (0.0 / 0)
       */
      @Override
      public Double apply(Map<OUTCOME_T, Double> input) {
        Collection<Double> scores = input.values();
        int size = scores.size();
        double total = 0;

        for (Double score : scores) {
          total += score;
        }
        return total / size;
      }
    }

    /**
     * Combines per-outcome scores by taking the largest one.
     */
    public static class MaxScores<OUTCOME_T> extends CombineScoreFunction<OUTCOME_T> {
      /**
       * @param input
       *          score for each outcome
       * @return the largest score in the map
       */
      @Override
      public Double apply(Map<OUTCOME_T, Double> input) {
        return Ordering.natural().max(input.values());
      }
    }
  }

  /**
   * Helper class for aggregating and computing mutual information statistics
   */
  public static class MutualInformationStats<OUTCOME_T> {
    /** Total occurrences recorded for each outcome. */
    protected Multiset<OUTCOME_T> classCounts;

    /** Occurrences recorded for each (feature name, outcome) pair. */
    protected Table<String, OUTCOME_T, Integer> classConditionalCounts;

    /** Value added to every joint count before it enters the mutual information sum. */
    protected double smoothingCount;

    /**
     * @param smoothingCount
     *          value added to each joint feature/outcome count when computing mutual information
     */
    public MutualInformationStats(double smoothingCount) {
      this.classCounts = HashMultiset.<OUTCOME_T> create();
      this.classConditionalCounts = HashBasedTable.<String, OUTCOME_T, Integer> create();
      // Plain assignment: the original "+=" only worked because the field defaults to 0.0.
      this.smoothingCount = smoothingCount;
    }

    /**
     * Records that a feature was seen together with an outcome.
     *
     * @param featureName
     *          name of the observed feature
     * @param outcome
     *          outcome the feature was paired with
     * @param occurrences
     *          number of observations to add to both the joint count and the outcome count
     */
    public void update(String featureName, OUTCOME_T outcome, int occurrences) {
      Integer count = this.classConditionalCounts.get(featureName, outcome);
      if (count == null) {
        count = 0;
      }
      this.classConditionalCounts.put(featureName, outcome, count + occurrences);
      this.classCounts.add(outcome, occurrences);
    }

    /**
     * Computes the mutual information between the presence of a feature and the presence of an
     * outcome from the counts recorded so far.
     *
     * @param featureName
     *          feature to score
     * @param outcome
     *          outcome to score the feature against
     * @return the sum over the four (feature present/absent, outcome present/absent) cells of
     *         {@code (c / n) * log(n * c / (featureCount * outcomeCount))}, where {@code c} is the
     *         smoothed joint count and {@code n} the total number of recorded occurrences
     */
    public double mutualInformation(String featureName, OUTCOME_T outcome) {
      // notation index of 0 means false, 1 mean true
      int[] featureCounts = new int[2];
      int[] outcomeCounts = new int[2];
      int[][] featureOutcomeCounts = new int[2][2];

      int n = this.classCounts.size();
      featureCounts[1] = this.sum(this.classConditionalCounts.row(featureName).values());
      featureCounts[0] = n - featureCounts[1];
      outcomeCounts[1] = this.classCounts.count(outcome);
      outcomeCounts[0] = n - outcomeCounts[1];

      featureOutcomeCounts[1][1] = this.classConditionalCounts.contains(featureName, outcome)
          ? this.classConditionalCounts.get(featureName, outcome)
          : 0;
      featureOutcomeCounts[1][0] = featureCounts[1] - featureOutcomeCounts[1][1];
      featureOutcomeCounts[0][1] = outcomeCounts[1] - featureOutcomeCounts[1][1];
      featureOutcomeCounts[0][0] = n - featureCounts[1] - outcomeCounts[1]
          + featureOutcomeCounts[1][1];

      double information = 0.0;
      for (int nFeature = 0; nFeature <= 1; nFeature++) {
        for (int nOutcome = 0; nOutcome <= 1; nOutcome++) {
          // Smooth in double precision. Adding into the int array would narrow the sum back to
          // int and silently drop any fractional smoothing count.
          double jointCount = featureOutcomeCounts[nFeature][nOutcome] + this.smoothingCount;
          if (jointCount == 0.0) {
            // An empty cell contributes nothing (p * log p -> 0 as p -> 0); evaluating it
            // directly would give 0 * -Infinity = NaN.
            continue;
          }
          // NOTE(review): the marginal counts are not smoothed, so a zero marginal combined with
          // a positive smoothing count still produces an infinite term here.
          information += jointCount
              / n
              * Math.log((n * jointCount)
                  / ((double) featureCounts[nFeature] * outcomeCounts[nOutcome]));
        }
      }

      return information;
    }

    /** Adds up a collection of counts. */
    private int sum(Collection<Integer> values) {
      int total = 0;
      for (int v : values) {
        total += v;
      }
      return total;
    }

    /**
     * Writes a tab-separated table of mutual information values (one row per feature, one column
     * per outcome) followed by the raw joint count table.
     *
     * @param outputURI
     *          location of the file to write
     * @throws IOException
     *           if the file cannot be opened or written
     */
    public void save(URI outputURI) throws IOException {
      File out = new File(outputURI);
      BufferedWriter writer = new BufferedWriter(new FileWriter(out));
      try {
        // Write out header
        writer.append("Mutual Information Data\n");
        writer.append("Feature\t");
        writer.append(Joiner.on("\t").join(this.classConditionalCounts.columnKeySet()));
        writer.append("\n");

        // Write out Mutual Information data
        for (String featureName : this.classConditionalCounts.rowKeySet()) {
          writer.append(featureName);
          for (OUTCOME_T outcome : this.classConditionalCounts.columnKeySet()) {
            writer.append("\t");
            writer.append(String.format(
                Locale.ROOT,
                "%f",
                this.mutualInformation(featureName, outcome)));
          }
          writer.append("\n");
        }
        writer.append("\n");
        writer.append(this.classConditionalCounts.toString());
      } finally {
        // Release the file handle even when writing fails part-way through.
        writer.close();
      }
    }

    /**
     * @param combineScoreMethod
     *          how per-outcome scores are reduced to a single score
     * @return a function mapping a feature name to its combined mutual information score
     */
    public ComputeFeatureScore<OUTCOME_T> getScoreFunction(CombineScoreMethod combineScoreMethod) {
      return new ComputeFeatureScore<OUTCOME_T>(this, combineScoreMethod);
    }

    /**
     * Maps a feature name to a single score by computing its mutual information with every
     * outcome and combining the results.
     */
    public static class ComputeFeatureScore<OUTCOME_T> implements Function<String, Double> {

      private MutualInformationStats<OUTCOME_T> stats;

      private CombineScoreFunction<OUTCOME_T> combineScoreFunction;

      /**
       * @param stats
       *          statistics the scores are computed from
       * @param combineMeasureType
       *          how per-outcome scores are reduced to a single score
       * @throws IllegalArgumentException
       *           if the combine method has no corresponding function
       */
      public ComputeFeatureScore(
          MutualInformationStats<OUTCOME_T> stats,
          CombineScoreMethod combineMeasureType) {
        this.stats = stats;
        switch (combineMeasureType) {
          case AVERAGE:
            this.combineScoreFunction = new CombineScoreMethod.AverageScores<OUTCOME_T>();
            // Without this break AVERAGE fell through and was overwritten by MaxScores.
            break;
          case MAX:
            this.combineScoreFunction = new CombineScoreMethod.MaxScores<OUTCOME_T>();
            break;
          default:
            throw new IllegalArgumentException("Unsupported combine score method: "
                + combineMeasureType);
        }

      }

      /**
       * @param featureName
       *          feature to score
       * @return the combined score over all outcomes present in the joint count table
       */
      @Override
      public Double apply(String featureName) {
        Set<OUTCOME_T> outcomes = stats.classConditionalCounts.columnKeySet();
        Map<OUTCOME_T, Double> featureOutcomeMI = Maps.newHashMap();
        for (OUTCOME_T outcome : outcomes) {
          featureOutcomeMI.put(outcome, stats.mutualInformation(featureName, outcome));
        }
        return this.combineScoreFunction.apply(featureOutcomeMI);
      }

    }

  }

  /**
   * Builds the key under which a feature is counted and selected.
   *
   * @param feature
   *          feature to name
   * @return the feature name alone for numeric values, otherwise {@code name:value}
   */
  public String nameFeature(Feature feature) {
    return (feature.getValue() instanceof Number) ? feature.getName() : feature.getName() + ":"
        + feature.getValue();
  }

  /** True once {@link #train(Iterable)} or {@link #load(URI)} has completed. */
  protected boolean isTrained;

  /** Statistics gathered by {@link #train(Iterable)}; not populated by {@link #load(URI)}. */
  private MutualInformationStats<OUTCOME_T> mutualInfoStats;

  private FeatureExtractor1<FOCUS_T> subExtractor;

  /** Maximum number of feature names read back by {@link #load(URI)}. */
  private int numFeatures;

  private CombineScoreMethod combineScoreMethod;

  private List<String> selectedFeatures;

  private double smoothingCount;

  /**
   * Creates an extractor using {@link CombineScoreMethod#MAX}, a smoothing count of 1.0 and 10
   * features.
   *
   * @param name
   *          name of this extractor
   * @param extractor
   *          sub-extractor whose features are selected from
   */
  public MutualInformationFeatureSelectionExtractor(
      String name,
      FeatureExtractor1<FOCUS_T> extractor) {
    super(name);
    this.init(extractor, CombineScoreMethod.MAX, 1.0, 10);
  }

  /**
   * Creates an extractor using {@link CombineScoreMethod#MAX} and a smoothing count of 1.0.
   *
   * @param name
   *          name of this extractor
   * @param extractor
   *          sub-extractor whose features are selected from
   * @param numFeatures
   *          maximum number of features read by {@link #load(URI)}
   */
  public MutualInformationFeatureSelectionExtractor(
      String name,
      FeatureExtractor1<FOCUS_T> extractor,
      int numFeatures) {
    super(name);
    this.init(extractor, CombineScoreMethod.MAX, 1.0, numFeatures);
  }

  /**
   * @param name
   *          name of this extractor
   * @param extractor
   *          sub-extractor whose features are selected from
   * @param combineMeasureType
   *          how per-outcome scores are reduced to a single score
   * @param smoothingCount
   *          value added to each joint count when computing mutual information
   * @param numFeatures
   *          maximum number of features read by {@link #load(URI)}
   */
  public MutualInformationFeatureSelectionExtractor(
      String name,
      FeatureExtractor1<FOCUS_T> extractor,
      CombineScoreMethod combineMeasureType,
      double smoothingCount,
      int numFeatures) {
    super(name);
    this.init(extractor, combineMeasureType, smoothingCount, numFeatures);
  }

  /** Shared constructor body: stores the configuration. */
  private void init(
      FeatureExtractor1<FOCUS_T> extractor,
      CombineScoreMethod method,
      double smoothCount,
      int n) {
    this.subExtractor = extractor;
    this.combineScoreMethod = method;
    this.smoothingCount = smoothCount;
    this.numFeatures = n;
  }

  /**
   * Extracts features with the sub-extractor and, once trained, keeps only the selected ones.
   *
   * @param view
   *          JCas passed through to the sub-extractor
   * @param focusAnnotation
   *          annotation passed through to the sub-extractor
   * @return the selected features when trained; otherwise a single {@link TransformableFeature}
   *         wrapping everything the sub-extractor produced
   * @throws CleartkExtractorException
   *           if the sub-extractor fails
   */
  @Override
  public List<Feature> extract(JCas view, FOCUS_T focusAnnotation) throws CleartkExtractorException {

    List<Feature> extracted = this.subExtractor.extract(view, focusAnnotation);
    List<Feature> result = new ArrayList<Feature>();
    if (this.isTrained) {
      // Filter out selected features
      result.addAll(Collections2.filter(extracted, this));
    } else {
      // We haven't trained this extractor yet, so just mark the existing features
      // for future modification, by creating one uber-container feature
      result.add(new TransformableFeature(this.name, extracted));
    }

    return result;
  }

  /**
   * Gathers feature/outcome co-occurrence counts from the instances and ranks every observed
   * feature by its combined mutual information score, highest first.
   *
   * @param instances
   *          training instances; only features for which {@code isTransformable} holds contribute
   */
  @Override
  public void train(Iterable<Instance<OUTCOME_T>> instances) {
    // aggregate statistics for all features and classes
    this.mutualInfoStats = new MutualInformationStats<OUTCOME_T>(this.smoothingCount);

    for (Instance<OUTCOME_T> instance : instances) {
      OUTCOME_T outcome = instance.getOutcome();
      for (Feature feature : instance.getFeatures()) {
        if (this.isTransformable(feature)) {
          for (Feature untransformedFeature : ((TransformableFeature) feature).getFeatures()) {
            mutualInfoStats.update(this.nameFeature(untransformedFeature), outcome, 1);
          }
        }
      }
    }
    // Compute mutual information score for each feature
    Set<String> featureNames = mutualInfoStats.classConditionalCounts.rowKeySet();

    // NOTE(review): the ranked list keeps every feature; the numFeatures cut-off is applied only
    // when the list is read back by load(). Confirm that train -> save -> load is the intended flow.
    this.selectedFeatures = Ordering.natural().onResultOf(
        this.mutualInfoStats.getScoreFunction(this.combineScoreMethod)).reverse().immutableSortedCopy(
        featureNames);
    this.isTrained = true;
  }

  /**
   * Writes the combine method followed by one {@code feature<TAB>score} line per ranked feature.
   *
   * @param uri
   *          location of the file to write
   * @throws IOException
   *           if the extractor has not been trained, has no statistics to score with, or the file
   *           cannot be written
   */
  @Override
  public void save(URI uri) throws IOException {
    if (!this.isTrained) {
      throw new IOException("MutualInformationFeatureExtractor: Cannot save before training.");
    }
    if (this.mutualInfoStats == null) {
      // After load() there are no statistics to recompute scores from. Fail before the target
      // file is opened (and thereby truncated) instead of with a NullPointerException afterwards.
      throw new IOException(
          "MutualInformationFeatureExtractor: Cannot save without training statistics.");
    }
    File out = new File(uri);
    BufferedWriter writer = new BufferedWriter(new FileWriter(out));
    try {
      writer.append("CombineScoreType\t");
      writer.append(this.combineScoreMethod.toString());
      writer.append("\n");

      ComputeFeatureScore<OUTCOME_T> computeScore = this.mutualInfoStats.getScoreFunction(this.combineScoreMethod);
      for (String feature : this.selectedFeatures) {
        writer.append(String.format(Locale.ROOT, "%s\t%f\n", feature, computeScore.apply(feature)));
      }
    } finally {
      // Release the file handle even when writing fails part-way through.
      writer.close();
    }

  }

  /**
   * Reads a file written by {@link #save(URI)}, keeping at most {@code numFeatures} feature names
   * in file order. The extractor's state is only replaced once the whole file has been parsed.
   *
   * @param uri
   *          location of the file to read
   * @throws IOException
   *           if the file cannot be read or its first line is missing or has no second
   *           tab-separated field
   */
  @Override
  public void load(URI uri) throws IOException {
    File in = new File(uri);
    List<String> loadedFeatures = Lists.newArrayList();
    CombineScoreMethod loadedMethod;

    BufferedReader reader = new BufferedReader(new FileReader(in));
    try {
      // First line specifies the combine utility type
      String header = reader.readLine();
      String[] headerFields = (header == null) ? new String[0] : header.split("\\t");
      if (headerFields.length < 2) {
        throw new IOException(
            "MutualInformationFeatureExtractor: Missing or malformed header line in " + uri);
      }
      loadedMethod = CombineScoreMethod.valueOf(headerFields[1]);

      // The rest of the lines are feature + selection scores
      String line;
      while (loadedFeatures.size() < this.numFeatures && (line = reader.readLine()) != null) {
        // The feature name is everything before the first tab (the whole line if there is none).
        int tab = line.indexOf('\t');
        loadedFeatures.add(tab < 0 ? line : line.substring(0, tab));
      }
    } finally {
      // Release the file handle even when parsing fails.
      reader.close();
    }

    this.combineScoreMethod = loadedMethod;
    this.selectedFeatures = loadedFeatures;
    this.isTrained = true;
  }

  /**
   * @param feature
   *          feature to test
   * @return true if the feature's {@link #nameFeature(Feature) name} is among the selected features
   */
  @Override
  public boolean apply(Feature feature) {
    return this.selectedFeatures.contains(this.nameFeature(feature));
  }

  /**
   * @return the feature names set by the last {@link #train(Iterable)} or {@link #load(URI)} call,
   *         or null if neither has run
   */
  public final List<String> getSelectedFeatures() {
    return this.selectedFeatures;
  }

}