001/** 002 * Copyright (c) 2012, Regents of the University of Colorado 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without 006 * modification, are permitted provided that the following conditions are met: 007 * 008 * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 009 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 010 * Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 011 * 012 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 013 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 014 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 015 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 016 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 017 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 018 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 019 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 020 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 021 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 022 * POSSIBILITY OF SUCH DAMAGE. 023 */ 024 025package org.cleartk.ml.feature.transform.extractor; 026 027import java.io.BufferedReader; 028import java.io.BufferedWriter; 029import java.io.File; 030import java.io.FileReader; 031import java.io.FileWriter; 032import java.io.IOException; 033import java.net.URI; 034import java.util.ArrayList; 035import java.util.HashSet; 036import java.util.List; 037import java.util.Locale; 038import java.util.Set; 039 040import org.apache.uima.jcas.JCas; 041import org.apache.uima.jcas.tcas.Annotation; 042import org.cleartk.ml.Feature; 043import org.cleartk.ml.Instance; 044import org.cleartk.ml.feature.extractor.CleartkExtractorException; 045import org.cleartk.ml.feature.extractor.FeatureExtractor1; 046import org.cleartk.ml.feature.transform.OneToOneTrainableExtractor_ImplBase; 047import org.cleartk.ml.feature.transform.TransformableFeature; 048 049import com.google.common.collect.LinkedHashMultiset; 050import com.google.common.collect.Multiset; 051 052/** 053 * Transforms count features produced by its subextractor into TF*IDF values 054 * <p> 055 * 056 * <br> 057 * Copyright (c) 2012, Regents of the University of Colorado <br> 058 * All rights reserved. 059 * 060 * @author Lee Becker 061 * 062 */ 063public class TfidfExtractor<OUTCOME_T, FOCUS_T extends Annotation> extends 064 OneToOneTrainableExtractor_ImplBase<OUTCOME_T> implements FeatureExtractor1<FOCUS_T> { 065 066 protected FeatureExtractor1<FOCUS_T> subExtractor; 067 068 protected boolean isTrained; 069 070 protected IDFMap idfMap; 071 072 public TfidfExtractor(String name) { 073 this(name, null); 074 } 075 076 /** 077 * 078 * @param extractor 079 * - This assumes that any extractors passed in will produce counts of some variety 080 */ 081 public TfidfExtractor(String name, FeatureExtractor1<FOCUS_T> extractor) { 082 super(name); 083 this.subExtractor = extractor; 084 this.isTrained = false; 085 this.idfMap = new IDFMap(); 086 } 087 088 @Override 089 protected Feature transform(Feature feature) { 090 int tf = (Integer) feature.getValue(); 091 double tfidf = tf * this.idfMap.getIDF(feature.getName()); 092 return new Feature("TF-IDF_" + feature.getName(), tfidf); 093 } 094 095 @Override 096 public List<Feature> extract(JCas view, FOCUS_T focusAnnotation) throws CleartkExtractorException { 097 098 List<Feature> extracted = this.subExtractor.extract(view, focusAnnotation); 099 List<Feature> result = new ArrayList<Feature>(); 100 if (this.isTrained) { 101 // We have trained / loaded a tf*idf model, so now fix up the values 102 for (Feature feature : extracted) { 103 result.add(this.transform(feature)); 104 } 105 } else { 106 // We haven't trained this extractor yet, so just mark the existing features 107 // for future modification, by creating one mega container feature 108 result.add(new TransformableFeature(this.name, extracted)); 109 } 110 111 return result; 112 } 113 114 protected IDFMap createIdfMap(Iterable<Instance<OUTCOME_T>> instances) { 115 IDFMap newIdfMap = new IDFMap(); 116 117 // Add instance's term frequencies to the global counts 118 for (Instance<OUTCOME_T> instance : instances) { 119 120 Set<String> featureNames = new HashSet<String>(); 121 // Grab the matching tf*idf features from the set of all features in an instance 122 for (Feature feature : instance.getFeatures()) { 123 if (this.isTransformable(feature)) { 124 // tf*idf features contain a list of features, these are actually what get added 125 // to our document frequency map 126 for (Feature untransformedFeature : ((TransformableFeature) feature).getFeatures()) { 127 featureNames.add(untransformedFeature.getName()); 128 } 129 } 130 } 131 132 for (String featureName : featureNames) { 133 newIdfMap.add(featureName); 134 } 135 newIdfMap.incTotalDocumentCount(); 136 137 } 138 return newIdfMap; 139 } 140 141 @Override 142 public void train(Iterable<Instance<OUTCOME_T>> instances) { 143 this.idfMap = this.createIdfMap(instances); 144 this.isTrained = true; 145 } 146 147 @Override 148 public void save(URI documentFreqDataURI) throws IOException { 149 this.idfMap.save(documentFreqDataURI); 150 } 151 152 @Override 153 public void load(URI documentFreqDataURI) throws IOException { 154 this.idfMap.load(documentFreqDataURI); 155 this.isTrained = true; 156 } 157 158 protected static class IDFMap { 159 private Multiset<String> documentFreqMap; 160 161 private int totalDocumentCount; 162 163 public IDFMap() { 164 this.documentFreqMap = LinkedHashMultiset.create(); 165 this.totalDocumentCount = 0; 166 } 167 168 public void add(String term) { 169 this.documentFreqMap.add(term); 170 } 171 172 public void incTotalDocumentCount() { 173 this.totalDocumentCount++; 174 } 175 176 public int getTotalDocumentCount() { 177 return this.totalDocumentCount; 178 } 179 180 public int getDF(String term) { 181 return this.documentFreqMap.count(term); 182 } 183 184 public double getIDF(String term) { 185 int df = this.getDF(term); 186 187 // see issue 396 for discussion about the ratio here: 188 // https://code.google.com/p/cleartk/issues/detail?id=396 189 // In short, we add 1 to the document frequency for smoothing purposes so that unseen words 190 // will not generate a NaN. We add 2 to the numerator to make sure that we get a positive 191 // value for words that appear in every document (i.e. when df == totalDocumentCount) 192 return Math.log((this.totalDocumentCount + 2) / (double) (df + 1)); 193 } 194 195 public void save(URI outputURI) throws IOException { 196 File out = new File(outputURI); 197 BufferedWriter writer = null; 198 writer = new BufferedWriter(new FileWriter(out)); 199 writer.append(String.format(Locale.ROOT, "NUM DOCUMENTS\t%d\n", this.totalDocumentCount)); 200 for (Multiset.Entry<String> entry : this.documentFreqMap.entrySet()) { 201 writer.append(String.format(Locale.ROOT, "%s\t%d\n", entry.getElement(), entry.getCount())); 202 } 203 writer.close(); 204 } 205 206 public void load(URI inputURI) throws IOException { 207 File in = new File(inputURI); 208 BufferedReader reader = null; 209 // this.documentFreqMap = LinkedHashMultiset.create(); 210 reader = new BufferedReader(new FileReader(in)); 211 // First line specifies the number of documents 212 String firstLine = reader.readLine(); 213 String[] keyValuePair = firstLine.split("\\t"); 214 this.totalDocumentCount = Integer.parseInt(keyValuePair[1]); 215 216 // The rest of the lines are the term counts 217 String line = null; 218 while ((line = reader.readLine()) != null) { 219 String[] termFreqPair = line.split("\\t"); 220 this.documentFreqMap.add(termFreqPair[0], Integer.parseInt(termFreqPair[1])); 221 } 222 reader.close(); 223 } 224 } 225}