001/** 
002 * Copyright (c) 2011, Regents of the University of Colorado 
003 * All rights reserved.
004 * 
005 * Redistribution and use in source and binary forms, with or without
006 * modification, are permitted provided that the following conditions are met:
007 * 
008 * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
009 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
010 * Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
011 * 
012 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
013 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
014 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
015 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
016 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
017 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
018 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
019 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
020 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
021 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
022 * POSSIBILITY OF SUCH DAMAGE. 
023 */
024package org.cleartk.ml.jar;
025
026import java.io.BufferedInputStream;
027import java.io.BufferedOutputStream;
028import java.io.File;
029import java.io.FileInputStream;
030import java.io.FileOutputStream;
031import java.io.IOException;
032import java.io.InputStream;
033import java.util.jar.Attributes;
034import java.util.jar.JarInputStream;
035import java.util.jar.JarOutputStream;
036import java.util.jar.Manifest;
037
038import org.cleartk.ml.DataWriter;
039import org.cleartk.ml.SequenceDataWriter;
040
/**
 * Superclass for builders which package classifiers as jar files. Saves a manifest from which new
 * instances of this class can later be loaded.
 * <p>
 * ClassifierBuilder subclasses that use feature or outcome encoders should instead subclass from
 * {@link EncodingJarClassifierBuilder}.
 * <p>
 * Subclasses will typically override:
 * <ul>
 * <li>{@link #saveToTrainingDirectory(File)} to add items to the model training directory</li>
 * <li>{@link #packageClassifier(File, JarOutputStream)} to copy items to the classifier jar</li>
 * <li>{@link #unpackageClassifier(JarInputStream)} to load items from the classifier jar</li>
 * <li>{@link #newClassifier()} to create a classifier from the loaded attributes</li>
 * </ul>
 * <p>
 * Subclasses must be instantiable through an accessible no-argument constructor, because
 * {@link #fromManifest(Manifest)} creates them reflectively from the class name stored in the
 * manifest.
 * 
 * <br>
 * Copyright (c) 2011, Regents of the University of Colorado <br>
 * All rights reserved.
 * 
 * @param <CLASSIFIER_TYPE>
 *          The type of classifier returned by {@link #newClassifier()} and the load methods.
 * @author Steven Bethard
 */
public abstract class JarClassifierBuilder<CLASSIFIER_TYPE> {

  /**
   * The name of the attribute where the classifier builder class is stored.
   */
  private static final Attributes.Name CLASSIFIER_BUILDER_ATTRIBUTE_NAME = new Attributes.Name(
      "classifierBuilderClass");

  /**
   * The name of the classifier jar file written into the training directory.
   */
  private static final String MODEL_FILE_NAME = "model.jar";

  /**
   * The name of the manifest file written into the training directory.
   */
  private static final String MANIFEST_FILE_NAME = "MANIFEST.MF";

  /**
   * The manifest associated with this classifier builder. The manifest will be saved to directories
   * and jar files (via {@link #saveToTrainingDirectory(File)} and {@link #packageClassifier(File)})
   * and can be used to load a new instance of the classifier builder (via
   * {@link #fromTrainingDirectory(File)} and {@link #fromManifest(Manifest)}).
   */
  protected Manifest manifest;

  /**
   * Loads a classifier builder from manifest in the training directory.
   * 
   * @param dir
   *          The training directory where the classifier builder was written by a call to
   *          {@link #saveToTrainingDirectory(File)}. This is typically the same directory as was
   *          used for {@link DirectoryDataWriterFactory#PARAM_OUTPUT_DIRECTORY}.
   * @return A new classifier builder.
   * @throws IOException
   *           If the manifest file cannot be opened or read.
   */
  public static JarClassifierBuilder<?> fromTrainingDirectory(File dir) throws IOException {
    InputStream stream = new BufferedInputStream(new FileInputStream(getManifestFile(dir)));
    Manifest manifest;
    try {
      manifest = new Manifest(stream);
    } finally {
      // release the file handle even when the manifest is malformed
      stream.close();
    }
    return fromManifest(manifest);
  }

  /**
   * Loads a classifier builder from a manifest.
   * 
   * The builder class is named by the {@code classifierBuilderClass} main attribute and is
   * instantiated through its no-argument constructor. The given manifest then replaces the default
   * manifest of the new builder.
   * 
   * @param manifest
   *          The classifier manifest, either from a training directory or from a classifier jar.
   * @return A new classifier builder.
   * @throws RuntimeException
   *           If the manifest has no {@code classifierBuilderClass} attribute, or if the named
   *           class cannot be loaded, is not a {@link JarClassifierBuilder}, or cannot be
   *           instantiated.
   */
  public static JarClassifierBuilder<?> fromManifest(Manifest manifest) {
    String className = manifest.getMainAttributes().getValue(CLASSIFIER_BUILDER_ATTRIBUTE_NAME);
    if (className == null) {
      // fail with an accurate message instead of a NullPointerException from Class.forName
      throw new RuntimeException("Manifest does not contain a "
          + CLASSIFIER_BUILDER_ATTRIBUTE_NAME + " attribute");
    }
    JarClassifierBuilder<?> builder;
    try {
      builder = Class.forName(className).asSubclass(JarClassifierBuilder.class)
          .getDeclaredConstructor().newInstance();
    } catch (Exception e) {
      throw new RuntimeException(
          "ClassifierBuilder class read from manifest does not exist or cannot be instantiated: "
              + className,
          e);
    }
    builder.manifest = manifest;
    return builder;
  }

  /**
   * Loads a classifier builder from the training directory and invokes
   * {@link #trainClassifier(File, String...)} and {@link #packageClassifier(File)}.
   * 
   * @param trainingDirectory
   *          The directory where {@link #saveToTrainingDirectory(File)} has saved the model
   *          training files.
   * @param trainingArguments
   *          Any additional arguments that should be passed to the classifier trainer via
   *          {@link #trainClassifier(File, String...)}.
   * @throws Exception
   *           If the builder cannot be loaded, or if training or packaging fails.
   */
  public static void trainAndPackage(File trainingDirectory, String... trainingArguments)
      throws Exception {
    JarClassifierBuilder<?> classifierBuilder = fromTrainingDirectory(trainingDirectory);
    classifierBuilder.trainClassifier(trainingDirectory, trainingArguments);
    classifierBuilder.packageClassifier(trainingDirectory);
  }

  /**
   * Creates a new classifier builder with a default manifest. The default manifest holds the
   * manifest version and the name of the concrete builder class.
   */
  public JarClassifierBuilder() {
    this.manifest = new Manifest();
    Attributes attributes = this.manifest.getMainAttributes();
    // Manifest.write emits nothing unless the version attribute is present
    attributes.put(Attributes.Name.MANIFEST_VERSION, "1.0");
    attributes.put(CLASSIFIER_BUILDER_ATTRIBUTE_NAME, this.getClass().getName());
  }

  /**
   * Write all information stored in the classifier builder to the training directory. Typically
   * called by {@link DataWriter#finish()} or {@link SequenceDataWriter#finish()}.
   * 
   * @param dir
   *          The directory where classifier information should be written.
   * @throws IOException
   *           If the manifest file cannot be created or written.
   */
  public void saveToTrainingDirectory(File dir) throws IOException {
    // save the manifest to the directory
    BufferedOutputStream manifestStream = new BufferedOutputStream(new FileOutputStream(
        getManifestFile(dir)));
    try {
      this.manifest.write(manifestStream);
    } finally {
      // close() flushes the buffer; it also runs when the write fails
      manifestStream.close();
    }
  }

  /**
   * Train a classifier from a training directory, as prepared by
   * {@link #saveToTrainingDirectory(File)}. Typically called at the command line by
   * {@link Train#main(String...)}.
   * 
   * @param dir
   *          The directory where training data and other classifier information has been written
   *          and where the trained classifier should be stored.
   * @param args
   *          Additional command line arguments for the classifier trainer.
   * @throws Exception
   *           If training fails.
   */
  public abstract void trainClassifier(File dir, String... args) throws Exception;

  /**
   * Create a classifier jar file from a directory where a classifier model was trained by
   * {@link #trainClassifier(File, String[])}. The jar is written to the file returned by
   * {@link #getModelJarFile(File)} and starts with this builder's manifest.
   * 
   * This method should typically not be overridden by subclasses - use
   * {@link #packageClassifier(File, JarOutputStream)} instead.
   * 
   * @param dir
   *          The directory where the classifier model was trained.
   * @throws IOException
   *           If the jar file cannot be created or written.
   */
  public void packageClassifier(File dir) throws IOException {
    BufferedOutputStream fileStream = new BufferedOutputStream(new FileOutputStream(
        getModelJarFile(dir)));
    try {
      JarOutputStream modelStream = new JarOutputStream(fileStream, this.manifest);
      this.packageClassifier(dir, modelStream);
      // finishes the jar and closes the underlying file stream on the normal path
      modelStream.close();
    } finally {
      // releases the file handle when anything above failed; harmless after a normal close
      fileStream.close();
    }
  }

  /**
   * Add elements to a classifier jar. The default implementation adds nothing.
   * 
   * @param dir
   *          The directory where the classifier model was trained.
   * @param modelStream
   *          The jar where the classifier is being written.
   * @throws IOException
   *           For errors reading the directory or writing to the jar.
   */
  protected void packageClassifier(File dir, JarOutputStream modelStream) throws IOException {
    // Used by subclasses
  }

  /**
   * Get the classifier jar file, as packaged by {@link #packageClassifier(File)}.
   * 
   * @param dir
   *          The directory where the classifier was packaged.
   * @return The classifier jar file.
   */
  public static File getModelJarFile(File dir) {
    return new File(dir, MODEL_FILE_NAME);
  }

  /**
   * Get the classifier jar file, as packaged by {@link #packageClassifier(File)}.
   * 
   * @param directoryName
   *          The path of the directory where the classifier was packaged.
   * @return The classifier jar file.
   */
  public static File getModelJarFile(String directoryName) {
    return getModelJarFile(new File(directoryName));
  }

  /**
   * Load a classifier packaged by {@link #packageClassifier(File)} from the jar file in the
   * training directory.
   * 
   * This method should typically not be overridden by subclasses - use
   * {@link #unpackageClassifier(JarInputStream)} and {@link #newClassifier()} instead.
   * 
   * @param dir
   *          The directory where the classifier was trained and packaged.
   * @return The loaded classifier.
   * @throws IOException
   *           If the jar file cannot be opened or read.
   */
  public CLASSIFIER_TYPE loadClassifierFromTrainingDirectory(File dir) throws IOException {
    File modelJarFile = getModelJarFile(dir);
    InputStream inputStream = new BufferedInputStream(new FileInputStream(modelJarFile));
    try {
      return this.loadClassifier(inputStream);
    } finally {
      inputStream.close();
    }
  }

  /**
   * Load a classifier packaged by {@link #packageClassifier(File)} from an {@link InputStream}.
   * 
   * If the stream is already a {@link JarInputStream} it is used directly and left open for the
   * caller to close. Otherwise it is wrapped in a {@link JarInputStream}, which is closed (along
   * with the given stream) before this method returns.
   * 
   * This method should typically not be overridden by subclasses - use
   * {@link #unpackageClassifier(JarInputStream)} and {@link #newClassifier()} instead.
   * 
   * @param inputStream
   *          The classifier stream.
   * @return The loaded classifier.
   * @throws IOException
   *           If the stream cannot be read as a jar.
   */
  public CLASSIFIER_TYPE loadClassifier(InputStream inputStream) throws IOException {

    // if it's already a Jar stream, don't re-wrap it
    if (inputStream instanceof JarInputStream) {
      JarInputStream modelStream = (JarInputStream) inputStream;
      this.unpackageClassifier(modelStream);
      return this.newClassifier();
    }

    // if we need to wrap it in a Jar stream, be sure to close the stream afterwards
    JarInputStream modelStream = new JarInputStream(inputStream);
    try {
      this.unpackageClassifier(modelStream);
      return this.newClassifier();
    } finally {
      modelStream.close();
    }
  }

  /**
   * Load classifier elements from a classifier jar. This will typically save such attributes as
   * instance variables in preparation for a call to {@link #newClassifier()}. The default
   * implementation reads nothing.
   * 
   * @param modelStream
   *          The classifier jar
   * @throws IOException
   *           For errors reading from the jar.
   */
  protected void unpackageClassifier(JarInputStream modelStream) throws IOException {
    // Used by subclasses
  }

  /**
   * Create a new classifier using the attributes loaded by
   * {@link #unpackageClassifier(JarInputStream)}.
   * 
   * @return The loaded and initialized classifier.
   */
  protected abstract CLASSIFIER_TYPE newClassifier();

  /**
   * Gets the standard {@link File} for a jar {@link Manifest}.
   * 
   * @param dir
   *          The directory from which a jar will be built.
   * @return The file where the manifest is expected.
   */
  private static File getManifestFile(File dir) {
    return new File(dir, MANIFEST_FILE_NAME);
  }
}