001/** 002 * Copyright (c) 2012, Regents of the University of Colorado 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without 006 * modification, are permitted provided that the following conditions are met: 007 * 008 * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 009 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 010 * Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 011 * 012 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 013 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 014 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 015 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 016 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 017 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 018 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 019 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 020 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 021 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 022 * POSSIBILITY OF SUCH DAMAGE. 023 */ 024 025package org.cleartk.ml.feature.transform.extractor; 026 027import java.io.BufferedReader; 028import java.io.BufferedWriter; 029import java.io.File; 030import java.io.FileReader; 031import java.io.FileWriter; 032import java.io.IOException; 033import java.io.ObjectInputStream; 034import java.io.ObjectOutputStream; 035import java.io.Serializable; 036import java.net.URI; 037import java.util.ArrayList; 038import java.util.HashMap; 039import java.util.List; 040import java.util.Locale; 041import java.util.Map; 042 043import org.apache.uima.jcas.JCas; 044import org.apache.uima.jcas.tcas.Annotation; 045import org.cleartk.ml.Feature; 046import org.cleartk.ml.Instance; 047import org.cleartk.ml.feature.extractor.CleartkExtractorException; 048import org.cleartk.ml.feature.extractor.FeatureExtractor1; 049import org.cleartk.ml.feature.transform.OneToOneTrainableExtractor_ImplBase; 050import org.cleartk.ml.feature.transform.TransformableFeature; 051 052/** 053 * Scales features produced by its subextractor to have mean=0, stddev=1 for a given feature 054 * <p> 055 * 056 * Copyright (c) 2012, Regents of the University of Colorado <br> 057 * All rights reserved. 058 * 059 * @author Lee Becker 060 * 061 */ 062public class ZeroMeanUnitStddevExtractor<OUTCOME_T, FOCUS_T extends Annotation> extends 063 OneToOneTrainableExtractor_ImplBase<OUTCOME_T> implements FeatureExtractor1<FOCUS_T> { 064 065 private FeatureExtractor1<FOCUS_T> subExtractor; 066 067 private boolean isTrained; 068 069 // This is read in after training for use in transformation 070 private Map<String, MeanStddevTuple> meanStddevMap; 071 072 public ZeroMeanUnitStddevExtractor(String name) { 073 this(name, null); 074 } 075 076 public ZeroMeanUnitStddevExtractor(String name, FeatureExtractor1<FOCUS_T> subExtractor) { 077 super(name); 078 this.subExtractor = subExtractor; 079 this.isTrained = false; 080 } 081 082 @Override 083 protected Feature transform(Feature feature) { 084 String featureName = feature.getName(); 085 MeanStddevTuple stats = this.meanStddevMap.get(featureName); 086 double value = ((Number) feature.getValue()).doubleValue(); 087 088 double zmus = 0.0d; 089 if (stats != null && stats.stddev > 0) { 090 zmus = (value - stats.mean) / stats.stddev; 091 return new Feature("ZMUS_" + featureName, zmus); 092 } else { 093 return null; 094 } 095 } 096 097 @Override 098 public List<Feature> extract(JCas view, FOCUS_T focusAnnotation) throws CleartkExtractorException { 099 100 List<Feature> extracted = this.subExtractor.extract(view, focusAnnotation); 101 List<Feature> result = new ArrayList<Feature>(); 102 if (this.isTrained) { 103 // We have trained / loaded a ZMUS model, so now fix up the values 104 for (Feature feature : extracted) { 105 Feature transformedFeature = this.transform(feature); 106 if (transformedFeature != null) 107 result.add(transformedFeature); 108 } 109 } else { 110 // We haven't trained this extractor yet, so just mark the existing features 111 // for future modification, by creating one mega container feature 112 result.add(new TransformableFeature(this.name, extracted)); 113 } 114 115 return result; 116 } 117 118 @Override 119 public void train(Iterable<Instance<OUTCOME_T>> instances) { 120 Map<String, MeanVarianceRunningStat> featureStatsMap = new HashMap<String, MeanVarianceRunningStat>(); 121 122 // keep a running mean and standard deviation for all applicable features 123 for (Instance<OUTCOME_T> instance : instances) { 124 // Grab the matching zmus (zero mean, unit stddev) features from the set of all features in an 125 // instance 126 for (Feature feature : instance.getFeatures()) { 127 if (this.isTransformable(feature)) { 128 // ZMUS features contain a list of features, these are actually what get added 129 // to our document frequency map 130 for (Feature untransformedFeature : ((TransformableFeature) feature).getFeatures()) { 131 String featureName = untransformedFeature.getName(); 132 Object featureValue = untransformedFeature.getValue(); 133 if (featureValue instanceof Number) { 134 MeanVarianceRunningStat stats; 135 if (featureStatsMap.containsKey(featureName)) { 136 stats = featureStatsMap.get(featureName); 137 } else { 138 stats = new MeanVarianceRunningStat(); 139 featureStatsMap.put(featureName, stats); 140 } 141 stats.add(((Number) featureValue).doubleValue()); 142 } else { 143 throw new IllegalArgumentException("Cannot normalize non-numeric feature values"); 144 } 145 } 146 } 147 } 148 } 149 150 this.meanStddevMap = new HashMap<String, MeanStddevTuple>(); 151 for (Map.Entry<String, MeanVarianceRunningStat> entry : featureStatsMap.entrySet()) { 152 MeanVarianceRunningStat stats = entry.getValue(); 153 this.meanStddevMap.put(entry.getKey(), new MeanStddevTuple(stats.mean(), stats.stddev())); 154 } 155 156 this.isTrained = true; 157 } 158 159 @Override 160 public void save(URI zmusDataUri) throws IOException { 161 // Write out tab separated values: feature_name, mean, stddev 162 File out = new File(zmusDataUri); 163 BufferedWriter writer = null; 164 writer = new BufferedWriter(new FileWriter(out)); 165 166 for (Map.Entry<String, MeanStddevTuple> entry : this.meanStddevMap.entrySet()) { 167 MeanStddevTuple tuple = entry.getValue(); 168 writer.append(String.format( 169 Locale.ROOT, 170 "%s\t%f\t%f\n", 171 entry.getKey(), 172 tuple.mean, 173 tuple.stddev)); 174 } 175 writer.close(); 176 } 177 178 @Override 179 public void load(URI zmusDataUri) throws IOException { 180 // Reads in tab separated values (feature name, mean, stddev) 181 File in = new File(zmusDataUri); 182 BufferedReader reader = null; 183 this.meanStddevMap = new HashMap<String, MeanStddevTuple>(); 184 reader = new BufferedReader(new FileReader(in)); 185 String line = null; 186 while ((line = reader.readLine()) != null) { 187 String[] featureMeanStddev = line.split("\\t"); 188 this.meanStddevMap.put( 189 featureMeanStddev[0], 190 new MeanStddevTuple( 191 Double.parseDouble(featureMeanStddev[1]), 192 Double.parseDouble(featureMeanStddev[2]))); 193 } 194 195 reader.close(); 196 197 this.isTrained = true; 198 } 199 200 public static class MeanStddevTuple { 201 202 public MeanStddevTuple(double mean, double stddev) { 203 this.mean = mean; 204 this.stddev = stddev; 205 } 206 207 public double mean; 208 209 public double stddev; 210 } 211 212 public static class MeanVarianceRunningStat implements Serializable { 213 214 private static final long serialVersionUID = 1L; 215 216 public MeanVarianceRunningStat() { 217 this.clear(); 218 } 219 220 public void init(int n, double mean, double variance) { 221 this.numSamples = n; 222 this.meanNew = mean; 223 this.varNew = variance; 224 } 225 226 public void add(double x) { 227 numSamples++; 228 229 if (this.numSamples == 1) { 230 meanOld = meanNew = x; 231 varOld = 0.0; 232 } else { 233 meanNew = meanOld + (x - meanOld) / numSamples; 234 varNew = varOld + (x - meanOld) * (x - meanNew); 235 // set up for next iteration 236 meanOld = meanNew; 237 varOld = varNew; 238 } 239 } 240 241 // 1-2 * 1-1.5 = .5 242 243 public void clear() { 244 this.numSamples = 0; 245 } 246 247 public int getNumSamples() { 248 return this.numSamples; 249 } 250 251 public double mean() { 252 return (this.numSamples > 0) ? meanNew : 0.0; 253 } 254 255 public double variance() { 256 return (this.numSamples > 1) ? varNew / (this.numSamples) : 0.0; 257 } 258 259 public double stddev() { 260 return Math.sqrt(this.variance()); 261 } 262 263 public double variancePop() { 264 return (this.numSamples > 1) ? varNew / (this.numSamples - 1) : 0.0; 265 } 266 267 public double stddevPop() { 268 return Math.sqrt(this.variancePop()); 269 } 270 271 private void writeObject(ObjectOutputStream out) throws IOException { 272 out.defaultWriteObject(); 273 out.writeInt(numSamples); 274 out.writeDouble(meanNew); 275 out.writeDouble(varNew); 276 277 } 278 279 private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { 280 in.defaultReadObject(); 281 numSamples = in.readInt(); 282 meanOld = meanNew = in.readDouble(); 283 varOld = varNew = in.readDouble(); 284 } 285 286 private int numSamples; 287 288 private double meanOld; 289 290 private double meanNew; 291 292 private double varOld; 293 294 private double varNew; 295 } 296}