001/** 
002 * Copyright (c) 2012, Regents of the University of Colorado 
003 * All rights reserved.
004 * 
005 * Redistribution and use in source and binary forms, with or without
006 * modification, are permitted provided that the following conditions are met:
007 * 
008 * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
009 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
010 * Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
011 * 
012 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
013 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
014 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
015 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
016 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
017 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
018 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
019 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
020 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
021 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
022 * POSSIBILITY OF SUCH DAMAGE. 
023 */
024
025package org.cleartk.ml.feature.transform.extractor;
026
027import java.io.BufferedReader;
028import java.io.BufferedWriter;
029import java.io.File;
030import java.io.FileReader;
031import java.io.FileWriter;
032import java.io.IOException;
033import java.net.URI;
034import java.util.ArrayList;
035import java.util.HashSet;
036import java.util.List;
037import java.util.Locale;
038import java.util.Set;
039
040import org.apache.uima.jcas.JCas;
041import org.apache.uima.jcas.tcas.Annotation;
042import org.cleartk.ml.Feature;
043import org.cleartk.ml.Instance;
044import org.cleartk.ml.feature.extractor.CleartkExtractorException;
045import org.cleartk.ml.feature.extractor.FeatureExtractor1;
046import org.cleartk.ml.feature.transform.OneToOneTrainableExtractor_ImplBase;
047import org.cleartk.ml.feature.transform.TransformableFeature;
048
049import com.google.common.collect.LinkedHashMultiset;
050import com.google.common.collect.Multiset;
051
052/**
053 * Transforms count features produced by its subextractor into TF*IDF values
054 * <p>
055 * 
056 * <br>
057 * Copyright (c) 2012, Regents of the University of Colorado <br>
058 * All rights reserved.
059 * 
060 * @author Lee Becker
061 * 
062 */
063public class TfidfExtractor<OUTCOME_T, FOCUS_T extends Annotation> extends
064    OneToOneTrainableExtractor_ImplBase<OUTCOME_T> implements FeatureExtractor1<FOCUS_T> {
065
066  protected FeatureExtractor1<FOCUS_T> subExtractor;
067
068  protected boolean isTrained;
069
070  protected IDFMap idfMap;
071
072  public TfidfExtractor(String name) {
073    this(name, null);
074  }
075
076  /**
077   * 
078   * @param extractor
079   *          - This assumes that any extractors passed in will produce counts of some variety
080   */
081  public TfidfExtractor(String name, FeatureExtractor1<FOCUS_T> extractor) {
082    super(name);
083    this.subExtractor = extractor;
084    this.isTrained = false;
085    this.idfMap = new IDFMap();
086  }
087
088  @Override
089  protected Feature transform(Feature feature) {
090    int tf = (Integer) feature.getValue();
091    double tfidf = tf * this.idfMap.getIDF(feature.getName());
092    return new Feature("TF-IDF_" + feature.getName(), tfidf);
093  }
094
095  @Override
096  public List<Feature> extract(JCas view, FOCUS_T focusAnnotation) throws CleartkExtractorException {
097
098    List<Feature> extracted = this.subExtractor.extract(view, focusAnnotation);
099    List<Feature> result = new ArrayList<Feature>();
100    if (this.isTrained) {
101      // We have trained / loaded a tf*idf model, so now fix up the values
102      for (Feature feature : extracted) {
103        result.add(this.transform(feature));
104      }
105    } else {
106      // We haven't trained this extractor yet, so just mark the existing features
107      // for future modification, by creating one mega container feature
108      result.add(new TransformableFeature(this.name, extracted));
109    }
110
111    return result;
112  }
113
114  protected IDFMap createIdfMap(Iterable<Instance<OUTCOME_T>> instances) {
115    IDFMap newIdfMap = new IDFMap();
116
117    // Add instance's term frequencies to the global counts
118    for (Instance<OUTCOME_T> instance : instances) {
119
120      Set<String> featureNames = new HashSet<String>();
121      // Grab the matching tf*idf features from the set of all features in an instance
122      for (Feature feature : instance.getFeatures()) {
123        if (this.isTransformable(feature)) {
124          // tf*idf features contain a list of features, these are actually what get added
125          // to our document frequency map
126          for (Feature untransformedFeature : ((TransformableFeature) feature).getFeatures()) {
127            featureNames.add(untransformedFeature.getName());
128          }
129        }
130      }
131
132      for (String featureName : featureNames) {
133        newIdfMap.add(featureName);
134      }
135      newIdfMap.incTotalDocumentCount();
136
137    }
138    return newIdfMap;
139  }
140
141  @Override
142  public void train(Iterable<Instance<OUTCOME_T>> instances) {
143    this.idfMap = this.createIdfMap(instances);
144    this.isTrained = true;
145  }
146
147  @Override
148  public void save(URI documentFreqDataURI) throws IOException {
149    this.idfMap.save(documentFreqDataURI);
150  }
151
152  @Override
153  public void load(URI documentFreqDataURI) throws IOException {
154    this.idfMap.load(documentFreqDataURI);
155    this.isTrained = true;
156  }
157
158  protected static class IDFMap {
159    private Multiset<String> documentFreqMap;
160
161    private int totalDocumentCount;
162
163    public IDFMap() {
164      this.documentFreqMap = LinkedHashMultiset.create();
165      this.totalDocumentCount = 0;
166    }
167
168    public void add(String term) {
169      this.documentFreqMap.add(term);
170    }
171
172    public void incTotalDocumentCount() {
173      this.totalDocumentCount++;
174    }
175
176    public int getTotalDocumentCount() {
177      return this.totalDocumentCount;
178    }
179
180    public int getDF(String term) {
181      return this.documentFreqMap.count(term);
182    }
183
184    public double getIDF(String term) {
185      int df = this.getDF(term);
186
187      // see issue 396 for discussion about the ratio here:
188      // https://code.google.com/p/cleartk/issues/detail?id=396
189      // In short, we add 1 to the document frequency for smoothing purposes so that unseen words
190      // will not generate a NaN. We add 2 to the numerator to make sure that we get a positive
191      // value for words that appear in every document (i.e. when df == totalDocumentCount)
192      return Math.log((this.totalDocumentCount + 2) / (double) (df + 1));
193    }
194
195    public void save(URI outputURI) throws IOException {
196      File out = new File(outputURI);
197      BufferedWriter writer = null;
198      writer = new BufferedWriter(new FileWriter(out));
199      writer.append(String.format(Locale.ROOT, "NUM DOCUMENTS\t%d\n", this.totalDocumentCount));
200      for (Multiset.Entry<String> entry : this.documentFreqMap.entrySet()) {
201        writer.append(String.format(Locale.ROOT, "%s\t%d\n", entry.getElement(), entry.getCount()));
202      }
203      writer.close();
204    }
205
206    public void load(URI inputURI) throws IOException {
207      File in = new File(inputURI);
208      BufferedReader reader = null;
209      // this.documentFreqMap = LinkedHashMultiset.create();
210      reader = new BufferedReader(new FileReader(in));
211      // First line specifies the number of documents
212      String firstLine = reader.readLine();
213      String[] keyValuePair = firstLine.split("\\t");
214      this.totalDocumentCount = Integer.parseInt(keyValuePair[1]);
215
216      // The rest of the lines are the term counts
217      String line = null;
218      while ((line = reader.readLine()) != null) {
219        String[] termFreqPair = line.split("\\t");
220        this.documentFreqMap.add(termFreqPair[0], Integer.parseInt(termFreqPair[1]));
221      }
222      reader.close();
223    }
224  }
225}