001/*
002 * Copyright (c) 2011, Regents of the University of Colorado 
003 * All rights reserved.
004 * 
005 * Redistribution and use in source and binary forms, with or without
006 * modification, are permitted provided that the following conditions are met:
007 * 
008 * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
009 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
010 * Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
011 * 
012 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
013 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
014 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
015 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
016 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
017 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
018 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
019 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
020 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
021 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
022 * POSSIBILITY OF SUCH DAMAGE. 
023 */
024package org.cleartk.timeml.tlink;
025
026import java.util.ArrayList;
027import java.util.Arrays;
028import java.util.Collections;
029import java.util.HashMap;
030import java.util.HashSet;
031import java.util.List;
032import java.util.Map;
033import java.util.Set;
034
035import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
036import org.apache.uima.jcas.JCas;
037import org.apache.uima.util.Level;
038import org.cleartk.ml.CleartkAnnotator;
039import org.cleartk.ml.Feature;
040import org.cleartk.ml.Instance;
041import org.cleartk.ml.feature.extractor.FeatureExtractor1;
042import org.cleartk.ml.feature.extractor.FeatureExtractor2;
043import org.cleartk.ml.feature.extractor.NamingExtractor1;
044import org.cleartk.syntax.constituent.type.TreebankNode;
045import org.cleartk.syntax.constituent.type.TreebankNodeUtil.TreebankNodePath;
046import org.cleartk.timeml.type.Anchor;
047import org.cleartk.timeml.type.TemporalLink;
048import org.cleartk.token.type.Sentence;
049import org.apache.uima.fit.util.JCasUtil;
050
051import com.google.common.base.Function;
052import com.google.common.collect.HashBasedTable;
053import com.google.common.collect.Lists;
054import com.google.common.collect.Ordering;
055import com.google.common.collect.Table;
056import com.google.common.collect.Table.Cell;
057
058/**
059 * <br>
060 * Copyright (c) 2011, Regents of the University of Colorado <br>
061 * All rights reserved.
062 * 
063 * @author Steven Bethard
064 */
065public abstract class TemporalLinkAnnotator_ImplBase<SOURCE extends Anchor, TARGET extends Anchor>
066    extends CleartkAnnotator<String> {
067
068  public static Map<String, String> REVERSE_RELATION = new HashMap<String, String>();
069  static {
070    // TimeML
071    REVERSE_RELATION.put("BEFORE", "AFTER");
072    REVERSE_RELATION.put("AFTER", "BEFORE");
073    REVERSE_RELATION.put("INCLUDES", "IS_INCLUDED");
074    REVERSE_RELATION.put("IS_INCLUDED", "INCLUDES");
075    REVERSE_RELATION.put("DURING", "DURING_INV");
076    REVERSE_RELATION.put("DURING_INV", "DURING");
077    REVERSE_RELATION.put("SIMULTANEOUS", "SIMULTANEOUS");
078    REVERSE_RELATION.put("IAFTER", "IBEFORE");
079    REVERSE_RELATION.put("IBEFORE", "IAFTER");
080    REVERSE_RELATION.put("IDENTITY", "IDENTITY");
081    REVERSE_RELATION.put("BEGINS", "BEGUN_BY");
082    REVERSE_RELATION.put("ENDS", "ENDED_BY");
083    REVERSE_RELATION.put("BEGUN_BY", "BEGINS");
084    REVERSE_RELATION.put("ENDED_BY", "ENDS");
085    // TempEval
086    REVERSE_RELATION.put("OVERLAP", "OVERLAP");
087    REVERSE_RELATION.put("OVERLAP-OR-AFTER", "BEFORE-OR-OVERLAP");
088    REVERSE_RELATION.put("BEFORE-OR-OVERLAP", "OVERLAP-OR-AFTER");
089    REVERSE_RELATION.put("VAGUE", "VAGUE");
090    REVERSE_RELATION.put("UNKNOWN", "UNKNOWN");
091    REVERSE_RELATION.put("NONE", "NONE");
092  }
093
094  private Class<SOURCE> sourceClass;
095
096  private Class<TARGET> targetClass;
097
098  private Set<String> trainingRelationTypes;
099  
100  private static final String NO_RELATION = "-NO-RELATION-";
101
102  protected List<FeatureExtractor1<SOURCE>> sourceExtractors;
103
104  protected List<FeatureExtractor1<TARGET>> targetExtractors;
105
106  protected List<FeatureExtractor2<Anchor,Anchor>> betweenExtractors;
107
108  protected class SourceTargetPair {
109
110    public SOURCE source;
111
112    public TARGET target;
113
114    public SourceTargetPair(SOURCE source, TARGET target) {
115      this.source = source;
116      this.target = target;
117    }
118  }
119
120  public TemporalLinkAnnotator_ImplBase(
121      Class<SOURCE> sourceClass,
122      Class<TARGET> targetClass,
123      String... trainingRelationTypes) {
124    this.sourceClass = sourceClass;
125    this.targetClass = targetClass;
126    this.trainingRelationTypes = new HashSet<String>(Arrays.asList(trainingRelationTypes));
127    this.sourceExtractors = Lists.newArrayList();
128    this.targetExtractors = Lists.newArrayList();
129    this.betweenExtractors = Lists.newArrayList();
130  }
131
132  protected void setSourceExtractors(List<FeatureExtractor1<SOURCE>> extractors) {
133    this.sourceExtractors = new ArrayList<FeatureExtractor1<SOURCE>>();
134    for (FeatureExtractor1<SOURCE> extractor : extractors) {
135      this.sourceExtractors.add(new NamingExtractor1<SOURCE>("Source", extractor));
136    }
137  }
138
139  protected void setTargetExtractors(List<FeatureExtractor1<TARGET>> extractors) {
140    this.targetExtractors = new ArrayList<FeatureExtractor1<TARGET>>();
141    for (FeatureExtractor1<TARGET> extractor : extractors) {
142      this.targetExtractors.add(new NamingExtractor1<TARGET>("Target", extractor));
143    }
144  }
145
146  protected void setBetweenExtractors(List<FeatureExtractor2<Anchor, Anchor>> extractors) {
147    this.betweenExtractors = extractors;
148  }
149
150  /**
151   * Returns the (source, target) anchor pairs for which a relation should be classified.
152   */
153  protected abstract List<SourceTargetPair> getSourceTargetPairs(JCas jCas);
154
155  @Override
156  public void process(JCas jCas) throws AnalysisEngineProcessException {
157    // collect all annotated relations
158    Table<SOURCE, TARGET, String> links = HashBasedTable.create();
159    for (TemporalLink tlink : JCasUtil.select(jCas, TemporalLink.class)) {
160      Anchor sourceAnchor = tlink.getSource();
161      Anchor targetAnchor = tlink.getTarget();
162
163      // collect the relation from source to target
164      if (this.sourceClass.isInstance(sourceAnchor) && this.targetClass.isInstance(targetAnchor)) {
165        SOURCE source = this.sourceClass.cast(sourceAnchor);
166        TARGET target = this.targetClass.cast(targetAnchor);
167        String relation = tlink.getRelationType();
168        links.put(source, target, relation);
169      }
170
171      // collect the (reversed) relation from target to source
172      if (this.sourceClass.isInstance(targetAnchor) && this.targetClass.isInstance(sourceAnchor)) {
173        SOURCE source = this.sourceClass.cast(targetAnchor);
174        TARGET target = this.targetClass.cast(sourceAnchor);
175        String relation = REVERSE_RELATION.get(tlink.getRelationType());
176        if (relation == null) {
177          throw new UnsupportedOperationException("Unknown relation: " + tlink.getRelationType());
178        }
179        links.put(source, target, relation);
180      }
181    }
182
183    // for each pair of anchors, write training data or classify the relation
184    for (SourceTargetPair pair : this.getSourceTargetPairs(jCas)) {
185      SOURCE source = pair.source;
186      TARGET target = pair.target;
187
188      // extract features
189      List<Feature> features = new ArrayList<Feature>();
190      for (FeatureExtractor1<SOURCE> extractor : this.sourceExtractors) {
191        features.addAll(extractor.extract(jCas, source));
192      }
193      for (FeatureExtractor1<TARGET> extractor : this.targetExtractors) {
194        features.addAll(extractor.extract(jCas, target));
195      }
196      for (FeatureExtractor2<Anchor, Anchor> extractor : this.betweenExtractors) {
197        features.addAll(extractor.extract(jCas, source, target));
198      }
199
200      // during training, write an instance if this pair was labeled
201      if (this.isTraining()) {
202        String relation = links.remove(source, target);
203        if (relation != null) {
204          if (!this.trainingRelationTypes.isEmpty() && !this.trainingRelationTypes.contains(relation)) {
205            relation = NO_RELATION;
206          }
207          this.dataWriter.write(new Instance<String>(relation, features));
208        }
209      } else {
210        String relation = this.classifier.classify(features);
211        if (!NO_RELATION.equals(relation)) {
212          int offset = jCas.getDocumentText().length();
213          TemporalLink tlink = new TemporalLink(jCas, offset, offset);
214          tlink.setSource(source);
215          tlink.setTarget(target);
216          tlink.setRelationType(relation);
217          tlink.addToIndexes();
218        }
219      }
220    }
221
222    // log a message for any links that were annotated but not used
223    if (!links.isEmpty()) {
224      
225      // map anchors to the sentences that contain them
226      Map<Anchor, Sentence> sentences = new HashMap<Anchor, Sentence>();
227      for (Sentence sentence : JCasUtil.select(jCas, Sentence.class)) {
228        for (SOURCE source : JCasUtil.selectCovered(jCas, this.sourceClass, sentence)) {
229          sentences.put(source, sentence);
230        }
231        for (TARGET target : JCasUtil.selectCovered(jCas, this.targetClass, sentence)) {
232          sentences.put(target, sentence);
233        }
234      }
235
236      // sort relations by the location of their source
237      List<Cell<SOURCE, TARGET, String>> cells = Lists.newArrayList(links.cellSet());
238      Ordering<Cell<SOURCE, TARGET, String>> order = Ordering.natural().onResultOf(
239          new Function<Cell<SOURCE, TARGET, String>, Integer>() {
240            @Override
241            public Integer apply(Cell<SOURCE, TARGET, String> cell) {
242              return cell.getRowKey().getBegin();
243            }
244          });
245      Collections.sort(cells, order);
246
247      // assemble an error message
248      StringBuilder errorBuilder = new StringBuilder();
249      errorBuilder.append("Missed ").append(links.size()).append(" TLINK(s)\n");
250      for (Cell<SOURCE, TARGET, String> cell : cells) {
251        SOURCE source = cell.getRowKey();
252        TARGET target = cell.getColumnKey();
253        String relation = cell.getValue();
254        Sentence sent1 = sentences.get(source);
255        Sentence sent2 = sentences.get(target);
256        errorBuilder.append(String.format(
257            "%s(%s, %s)\n%s\n%s\n",
258            relation,
259            source.getCoveredText(),
260            target.getCoveredText(),
261            sent1 == null ? null : sent1.getCoveredText(),
262            sent2 == null ? null : sent2.getCoveredText()));
263      }
264      this.getContext().getLogger().log(Level.FINE, errorBuilder.toString());
265    }
266  }
267
268  protected static String noLeavesPath(TreebankNodePath path) {
269    if (path.getCommonAncestor() == null) {
270      return null;
271    }
272    List<String> sourceTypes = new ArrayList<String>();
273    for (TreebankNode node : path.getSourceToAncestorPath()) {
274      if (!node.getLeaf()) {
275        sourceTypes.add(node.getNodeType());
276      }
277    }
278    List<String> targetTypes = new ArrayList<String>();
279    for (TreebankNode node : path.getTargetToAncestorPath()) {
280      if (!node.getLeaf()) {
281        targetTypes.add(node.getNodeType());
282      }
283    }
284    Collections.reverse(targetTypes);
285    StringBuilder builder = new StringBuilder();
286    for (String type : sourceTypes) {
287      builder.append(type).append('>');
288    }
289    builder.append(path.getCommonAncestor().getNodeType());
290    for (String type : targetTypes) {
291      builder.append('<').append(type);
292    }
293    return builder.toString();
294  }
295}