001/** 
002 * Copyright (c) 2012, Regents of the University of Colorado 
003 * All rights reserved.
004 * 
005 * Redistribution and use in source and binary forms, with or without
006 * modification, are permitted provided that the following conditions are met:
007 * 
008 * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
009 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
010 * Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
011 * 
012 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
013 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
014 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
015 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
016 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
017 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
018 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
019 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
020 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
021 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
022 * POSSIBILITY OF SUCH DAMAGE. 
023 */
024
025package org.cleartk.ml.feature.transform.extractor;
026
027import java.io.BufferedReader;
028import java.io.BufferedWriter;
029import java.io.File;
030import java.io.FileReader;
031import java.io.FileWriter;
032import java.io.IOException;
033import java.io.ObjectInputStream;
034import java.io.ObjectOutputStream;
035import java.io.Serializable;
036import java.net.URI;
037import java.util.ArrayList;
038import java.util.HashMap;
039import java.util.List;
040import java.util.Locale;
041import java.util.Map;
042
043import org.apache.uima.jcas.JCas;
044import org.apache.uima.jcas.tcas.Annotation;
045import org.cleartk.ml.Feature;
046import org.cleartk.ml.Instance;
047import org.cleartk.ml.feature.extractor.CleartkExtractorException;
048import org.cleartk.ml.feature.extractor.FeatureExtractor1;
049import org.cleartk.ml.feature.transform.OneToOneTrainableExtractor_ImplBase;
050import org.cleartk.ml.feature.transform.TransformableFeature;
051
052/**
 * Scales each numeric feature produced by its sub-extractor to zero mean and unit standard
 * deviation (a z-score), using per-feature statistics gathered during training.
054 * <p>
055 * 
056 * Copyright (c) 2012, Regents of the University of Colorado <br>
057 * All rights reserved.
058 * 
059 * @author Lee Becker
060 * 
061 */
062public class ZeroMeanUnitStddevExtractor<OUTCOME_T, FOCUS_T extends Annotation> extends
063    OneToOneTrainableExtractor_ImplBase<OUTCOME_T> implements FeatureExtractor1<FOCUS_T> {
064
065  private FeatureExtractor1<FOCUS_T> subExtractor;
066
067  private boolean isTrained;
068
069  // This is read in after training for use in transformation
070  private Map<String, MeanStddevTuple> meanStddevMap;
071
072  public ZeroMeanUnitStddevExtractor(String name) {
073    this(name, null);
074  }
075
076  public ZeroMeanUnitStddevExtractor(String name, FeatureExtractor1<FOCUS_T> subExtractor) {
077    super(name);
078    this.subExtractor = subExtractor;
079    this.isTrained = false;
080  }
081
082  @Override
083  protected Feature transform(Feature feature) {
084    String featureName = feature.getName();
085    MeanStddevTuple stats = this.meanStddevMap.get(featureName);
086    double value = ((Number) feature.getValue()).doubleValue();
087
088    double zmus = 0.0d;
089    if (stats != null && stats.stddev > 0) {
090      zmus = (value - stats.mean) / stats.stddev;
091      return new Feature("ZMUS_" + featureName, zmus);
092    } else {
093      return null;
094    }
095  }
096
097  @Override
098  public List<Feature> extract(JCas view, FOCUS_T focusAnnotation) throws CleartkExtractorException {
099
100    List<Feature> extracted = this.subExtractor.extract(view, focusAnnotation);
101    List<Feature> result = new ArrayList<Feature>();
102    if (this.isTrained) {
103      // We have trained / loaded a ZMUS model, so now fix up the values
104      for (Feature feature : extracted) {
105        Feature transformedFeature = this.transform(feature);
106        if (transformedFeature != null)
107          result.add(transformedFeature);
108      }
109    } else {
110      // We haven't trained this extractor yet, so just mark the existing features
111      // for future modification, by creating one mega container feature
112      result.add(new TransformableFeature(this.name, extracted));
113    }
114
115    return result;
116  }
117
118  @Override
119  public void train(Iterable<Instance<OUTCOME_T>> instances) {
120    Map<String, MeanVarianceRunningStat> featureStatsMap = new HashMap<String, MeanVarianceRunningStat>();
121
122    // keep a running mean and standard deviation for all applicable features
123    for (Instance<OUTCOME_T> instance : instances) {
124      // Grab the matching zmus (zero mean, unit stddev) features from the set of all features in an
125      // instance
126      for (Feature feature : instance.getFeatures()) {
127        if (this.isTransformable(feature)) {
128          // ZMUS features contain a list of features, these are actually what get added
129          // to our document frequency map
130          for (Feature untransformedFeature : ((TransformableFeature) feature).getFeatures()) {
131            String featureName = untransformedFeature.getName();
132            Object featureValue = untransformedFeature.getValue();
133            if (featureValue instanceof Number) {
134              MeanVarianceRunningStat stats;
135              if (featureStatsMap.containsKey(featureName)) {
136                stats = featureStatsMap.get(featureName);
137              } else {
138                stats = new MeanVarianceRunningStat();
139                featureStatsMap.put(featureName, stats);
140              }
141              stats.add(((Number) featureValue).doubleValue());
142            } else {
143              throw new IllegalArgumentException("Cannot normalize non-numeric feature values");
144            }
145          }
146        }
147      }
148    }
149
150    this.meanStddevMap = new HashMap<String, MeanStddevTuple>();
151    for (Map.Entry<String, MeanVarianceRunningStat> entry : featureStatsMap.entrySet()) {
152      MeanVarianceRunningStat stats = entry.getValue();
153      this.meanStddevMap.put(entry.getKey(), new MeanStddevTuple(stats.mean(), stats.stddev()));
154    }
155
156    this.isTrained = true;
157  }
158
159  @Override
160  public void save(URI zmusDataUri) throws IOException {
161    // Write out tab separated values: feature_name, mean, stddev
162    File out = new File(zmusDataUri);
163    BufferedWriter writer = null;
164    writer = new BufferedWriter(new FileWriter(out));
165
166    for (Map.Entry<String, MeanStddevTuple> entry : this.meanStddevMap.entrySet()) {
167      MeanStddevTuple tuple = entry.getValue();
168      writer.append(String.format(
169          Locale.ROOT,
170          "%s\t%f\t%f\n",
171          entry.getKey(),
172          tuple.mean,
173          tuple.stddev));
174    }
175    writer.close();
176  }
177
178  @Override
179  public void load(URI zmusDataUri) throws IOException {
180    // Reads in tab separated values (feature name, mean, stddev)
181    File in = new File(zmusDataUri);
182    BufferedReader reader = null;
183    this.meanStddevMap = new HashMap<String, MeanStddevTuple>();
184    reader = new BufferedReader(new FileReader(in));
185    String line = null;
186    while ((line = reader.readLine()) != null) {
187      String[] featureMeanStddev = line.split("\\t");
188      this.meanStddevMap.put(
189          featureMeanStddev[0],
190          new MeanStddevTuple(
191              Double.parseDouble(featureMeanStddev[1]),
192              Double.parseDouble(featureMeanStddev[2])));
193    }
194
195    reader.close();
196
197    this.isTrained = true;
198  }
199
200  public static class MeanStddevTuple {
201
202    public MeanStddevTuple(double mean, double stddev) {
203      this.mean = mean;
204      this.stddev = stddev;
205    }
206
207    public double mean;
208
209    public double stddev;
210  }
211
212  public static class MeanVarianceRunningStat implements Serializable {
213
214    private static final long serialVersionUID = 1L;
215
216    public MeanVarianceRunningStat() {
217      this.clear();
218    }
219
220    public void init(int n, double mean, double variance) {
221      this.numSamples = n;
222      this.meanNew = mean;
223      this.varNew = variance;
224    }
225
226    public void add(double x) {
227      numSamples++;
228
229      if (this.numSamples == 1) {
230        meanOld = meanNew = x;
231        varOld = 0.0;
232      } else {
233        meanNew = meanOld + (x - meanOld) / numSamples;
234        varNew = varOld + (x - meanOld) * (x - meanNew);
235        // set up for next iteration
236        meanOld = meanNew;
237        varOld = varNew;
238      }
239    }
240
241    // 1-2 * 1-1.5 = .5
242
243    public void clear() {
244      this.numSamples = 0;
245    }
246
247    public int getNumSamples() {
248      return this.numSamples;
249    }
250
251    public double mean() {
252      return (this.numSamples > 0) ? meanNew : 0.0;
253    }
254
255    public double variance() {
256      return (this.numSamples > 1) ? varNew / (this.numSamples) : 0.0;
257    }
258
259    public double stddev() {
260      return Math.sqrt(this.variance());
261    }
262
263    public double variancePop() {
264      return (this.numSamples > 1) ? varNew / (this.numSamples - 1) : 0.0;
265    }
266
267    public double stddevPop() {
268      return Math.sqrt(this.variancePop());
269    }
270
271    private void writeObject(ObjectOutputStream out) throws IOException {
272      out.defaultWriteObject();
273      out.writeInt(numSamples);
274      out.writeDouble(meanNew);
275      out.writeDouble(varNew);
276
277    }
278
279    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
280      in.defaultReadObject();
281      numSamples = in.readInt();
282      meanOld = meanNew = in.readDouble();
283      varOld = varNew = in.readDouble();
284    }
285
286    private int numSamples;
287
288    private double meanOld;
289
290    private double meanNew;
291
292    private double varOld;
293
294    private double varNew;
295  }
296}