001/*
002 * Copyright 2005-2007 The Kuali Foundation
003 *
004 *
005 * Licensed under the Educational Community License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 * http://www.opensource.org/licenses/ecl2.php
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.kualigan.tools.ant.tasks;
018
019import java.io.PrintStream;
020import java.io.Reader;
021
022import java.lang.reflect.Field;
023
024import java.sql.Blob;
025import java.sql.Clob;
026import java.sql.Connection;
027import java.sql.DriverManager;
028import java.sql.DatabaseMetaData;
029import java.sql.PreparedStatement;
030import java.sql.SQLException;
031import java.sql.ResultSet;
032import java.sql.ResultSetMetaData;
033import java.sql.Statement;
034import java.sql.Types;
035
036import java.util.ArrayList;
037import java.util.Collection;
038import java.util.HashMap;
039import java.util.Map;
040import java.util.Observable;
041import java.util.Observer;
042
043import org.apache.tools.ant.BuildException;
044import org.apache.tools.ant.DirectoryScanner;
045import org.apache.tools.ant.Main;
046import org.apache.tools.ant.Task;
047import org.apache.tools.ant.types.FileSet;
048
049import static org.apache.tools.ant.Project.MSG_DEBUG;
050
051/**
052 *
053 * @author Leo Przybylski (przybyls@arizona.edu)
054 */
055public class MigrateData extends Task {
056    
057    private static final String[] carr = new String[] {"|", "\\", "-", "/"};
058    private static final String RECORD_COUNT_QUERY = "select count(*) as \"COUNT\" from %s";
059    private static final String SELECT_ALL_QUERY   = "select * from %s";
060    private static final String INSERT_STATEMENT   = "insert into %s (%s) values (%s)";
061    private static final String DATE_CONVERSION    = "TO_DATE('%s', 'YYYYMMDDHH24MISS')";
062    private static final String COUNT_FIELD        = "COUNT";
063    private static final String LIQUIBASE_TABLE    = "DATABASECHANGELOG";
064    private static final int[]  QUOTED_TYPES       = 
065        new int[] {Types.CHAR, Types.VARCHAR, Types.TIME, Types.LONGVARCHAR, Types.DATE, Types.TIMESTAMP};
066
067    private static final String HSQLDB_PUBLIC      = "PUBLIC";
068    private static final int    MAX_THREADS        = 3;
069
070    
071    private String source;
072    private String target;
073    private int threadCount;
074
075    public MigrateData() { 
076        int threadCount = 1;
077    }
078
079    
080    public void setSource(String refid) {
081        this.source = refid;
082    }
083    
084    public String getSource() {
085        return this.source;
086    }
087
088    public void setTarget(String refid) {
089        this.target = refid;
090    }
091
092    public String getTarget() {
093        return this.target;
094    }
095    
096    public void execute() {
097        final RdbmsConfig source = (RdbmsConfig) getProject().getReference(getSource());
098        final RdbmsConfig target = (RdbmsConfig) getProject().getReference(getTarget());
099
100        log("Migrating data from " + source.getUrl() + " to " + target.getUrl());
101
102        final Incrementor recordCountIncrementor = new Incrementor();
103        final Map<String, Integer> tableData = getTableData(source, target, recordCountIncrementor);
104
105        log("Copying " + tableData.size() + " tables");
106
107        float recordVisitor = 0;
108        final ProgressObserver progressObserver = new ProgressObserver(recordCountIncrementor.getValue(),
109                                                                       48f, 48f/100,
110                                                                       "\r|%s[%s] %3d%% (%d/%d) records");
111        final ProgressObservable observable = new ProgressObservable();
112        observable.addObserver(progressObserver);
113
114        final ThreadGroup tgroup = new ThreadGroup("Migration Threads");
115
116        for (final String tableName : tableData.keySet()) {
117            debug("Migrating table " + tableName + " with " + tableData.get(tableName) + " records");
118            /*
119            if (tgroup.activeCount() < MAX_THREADS) {
120                new Thread(tgroup, new Runnable() {
121                        public void run() {
122                            migrate(source, target, tableName, observable);
123                        }
124                    }).start();
125            }
126            else {
127            */
128                final Map<String,Integer> columns = new HashMap<String, Integer>();
129                migrate(source, target, tableName, observable);
130                // }
131        }
132
133        // Wait for other threads to finish
134        try {
135            while(tgroup.activeCount() > 0) {
136                Thread.sleep(5000);
137            }
138        }
139        catch (InterruptedException e) {
140        }
141
142        try {
143            final Connection targetDb = openConnection(target);
144            if (targetDb.getMetaData().getDriverName().toLowerCase().contains("hsqldb")) {
145                Statement st = targetDb.createStatement();
146                st.execute("CHECKPOINT"); 
147                st.close();
148            }
149            targetDb.close();
150        }
151        catch (Exception e) {
152            throw new BuildException(e);
153        }
154    }
155
156    protected void migrate(final RdbmsConfig source, 
157                           final RdbmsConfig target, 
158                           final String tableName, 
159                           final ProgressObservable observable) {
160        final Connection sourceDb = openConnection(source);
161        final Connection targetDb = openConnection(target);
162        source.setConnection(sourceDb);
163        target.setConnection(targetDb);
164        final Map<String, Integer> columns = getColumnMap(source, target, tableName);
165
166        if (columns.size() < 1) {
167            log("Columns are empty for " + tableName);
168            return;
169        }
170
171        PreparedStatement toStatement = prepareStatement(targetDb, tableName, columns);
172        Statement fromStatement = null;
173
174        final boolean hasClob = columns.values().contains(Types.CLOB);
175        int recordsLost = 0;
176        
177        try {
178            fromStatement = sourceDb.createStatement();
179            final ResultSet results = fromStatement.executeQuery(String.format(SELECT_ALL_QUERY, tableName));
180            
181            try {
182                while (results.next()) {
183                    try {
184                        toStatement.clearParameters();
185                        
186                        int i = 1;
187                        for (String columnName : columns.keySet()) {
188                            final Object value = results.getObject(columnName);
189                            
190                            if (value != null) {
191                                try {
192                                    handleLob(toStatement, value, i);
193                                }
194                                catch (Exception e) {
195                                    System.err.println(String.format("Error processing %s.%s %s", tableName, columnName, columns.get(columnName)));
196                                    if (Clob.class.isAssignableFrom(value.getClass())) {
197                                        System.err.println("Got exception trying to insert CLOB with length" + ((Clob) value).length());
198                                    }
199                                    e.printStackTrace();
200                                }
201                            }
202                            else {
203                                toStatement.setObject(i,value);
204                            } 
205                            i++;
206                        }
207                        
208                        boolean retry = true;
209                        int retry_count = 0;
210                        while(retry) {
211                            try {
212                                toStatement.execute();
213                                retry = false;
214                            }
215                            catch (SQLException sqle) {
216                                retry = false;
217                                if (sqle.getMessage().contains("ORA-00942")) {
218                                    log("Couldn't find " + tableName);
219                                    log("Tried insert statement " + getStatementBuffer(tableName, columns));
220                                    // sqle.printStackTrace();
221                                }
222                                else if (sqle.getMessage().contains("ORA-12519")) {
223                                    retry = true;
224                                    log("Tried insert statement " + getStatementBuffer(tableName, columns));
225                                    sqle.printStackTrace();
226                                }
227                                else if (sqle.getMessage().contains("IN or OUT")) {
228                                    log("Column count was " + columns.keySet().size());
229                                }
230                                else if (sqle.getMessage().contains("Error reading")) {
231                                    if (retry_count > 5) {
232                                        log("Tried insert statement " + getStatementBuffer(tableName, columns));
233                                        retry = false;
234                                    }
235                                    retry_count++;
236                                }
237                                else {
238                                    sqle.printStackTrace();
239                                }
240                            }
241                        }
242                    }
243                    catch (Exception e) {
244                        recordsLost++;
245                        throw e;
246                    }
247                    finally {
248                        observable.incrementRecord();
249                    }
250                }
251            }
252            finally {
253                if (results != null) {
254                    try {
255                        results.close();
256                    }
257                    catch(Exception e) {
258                    }
259                }
260            }
261        }
262        catch (Exception e) {
263            throw new BuildException(e);
264        }
265        finally {
266            if (sourceDb != null) {
267                try {
268                    if (sourceDb.getMetaData().getDriverName().toLowerCase().contains("hsqldb")) {
269                        Statement st = sourceDb.createStatement();
270                        st.execute("CHECKPOINT"); 
271                        st.close();
272                    }
273                    fromStatement.close();
274                    sourceDb.close();
275                }
276                catch (Exception e) {
277                }
278            }
279
280            if (targetDb != null) {
281                try {
282                    targetDb.commit();
283                    if (targetDb.getMetaData().getDriverName().toLowerCase().contains("hsql")) {
284                        Statement st = targetDb.createStatement();
285                        st.execute("CHECKPOINT"); 
286                        st.close();
287                    }
288                    toStatement.close();
289                    targetDb.close();
290                }
291                catch (Exception e) {
292                    log("Error closing database connection");
293                    e.printStackTrace();
294                }
295            }
296            debug("Lost " +recordsLost + " records");
297            columns.clear();
298        }
299    }
300
301    protected void handleLob(final PreparedStatement toStatement, final Object value, final int i) throws SQLException {
302        if (Clob.class.isAssignableFrom(value.getClass())) {
303            toStatement.setAsciiStream(i, ((Clob) value).getAsciiStream(), ((Clob) value).length());
304        }
305        else if (Blob.class.isAssignableFrom(value.getClass())) {
306            toStatement.setBinaryStream(i, ((Blob) value).getBinaryStream(), ((Blob) value).length());
307        }
308        else {
309            toStatement.setObject(i,value);
310        } 
311    }
312
313    protected PreparedStatement prepareStatement(Connection conn, String tableName, Map<String, Integer> columns) {
314        final String statement = getStatementBuffer(tableName, columns);
315        
316        try {
317            return conn.prepareStatement(statement);
318        }
319        catch (Exception e) {
320            throw new BuildException(e);
321        }
322    }
323
324    private String getStatementBuffer(String tableName, Map<String,Integer> columns) {
325        String retval = null;
326
327        final StringBuilder names  = new StringBuilder();
328        final StringBuilder values = new StringBuilder();
329        for (String columnName : columns.keySet()) {
330            names.append(columnName).append(",");
331            values.append("?,");
332        }
333
334        names.setLength(names.length() - 1);
335        values.setLength(values.length() - 1);
336        retval = String.format(INSERT_STATEMENT, tableName, names, values);
337        
338
339        return retval;
340    }
341
342    protected boolean isValidTable(final DatabaseMetaData metadata, final String tableName) {
343        return !(tableName.startsWith("BIN$") || tableName.toUpperCase().startsWith(LIQUIBASE_TABLE) || isSequence(metadata, tableName));
344    }
345
346    protected boolean isSequence(final DatabaseMetaData metadata, final String tableName) {
347        final RdbmsConfig source = (RdbmsConfig) getProject().getReference(getSource());
348        try {
349            final ResultSet rs = metadata.getColumns(null, source.getSchema(), tableName, null);
350            int columnCount = 0;
351            boolean hasId = false;
352            try {
353                while (rs.next()) {
354                    columnCount++;
355                    if ("yes".equalsIgnoreCase(rs.getString("IS_AUTOINCREMENT"))) {
356                        hasId = true;
357                    }
358                }
359            }
360            finally {
361                if (rs != null) {
362                    try {
363                        rs.close();
364                    }
365                    catch (Exception e) {
366                    }
367                }
368                return (columnCount == 1 && hasId);
369            }
370        }
371        catch (Exception e) {
372            return false;
373        }
374    }
375
376    /**
377     * Get a list of table names available mapped to row counts
378     */
379    protected Map<String, Integer> getTableData(RdbmsConfig source, RdbmsConfig target, Incrementor incrementor) {
380        Connection sourceConn = openConnection(source);
381        Connection targetConn = openConnection(target);
382        final Map<String, Integer> retval = new HashMap<String, Integer>();
383        final Collection<String> toRemove = new ArrayList<String>();
384
385        debug("Looking up table names");
386        try {
387            final DatabaseMetaData metadata = sourceConn.getMetaData();
388            final ResultSet tableResults = metadata.getTables(sourceConn.getCatalog(), source.getSchema(), null, new String[] { "TABLE" });
389            while (tableResults.next()) {
390                final String tableName = tableResults.getString("TABLE_NAME");
391                if (!isValidTable(metadata, tableName)) {
392                    continue;
393                }
394                if (tableName.toUpperCase().startsWith(LIQUIBASE_TABLE)) continue;
395                final int rowCount = getTableRecordCount(sourceConn, tableName);
396                if (rowCount < 1) { // no point in going through tables with no data
397                    
398                }
399                incrementor.increment(rowCount);
400                debug("Adding table " + tableName);
401                retval.put(tableName, rowCount);
402            }
403            tableResults.close();
404        }
405        catch (Exception e) {
406            throw new BuildException(e);
407        }
408        finally {
409            if (sourceConn != null) {
410                try {
411                    sourceConn.close();
412                    sourceConn = null;
413                }
414                catch (Exception e) {
415                }
416            }
417        }
418
419        try {
420            for (String tableName : retval.keySet()) {
421                final ResultSet tableResults = targetConn.getMetaData().getTables(targetConn.getCatalog(), target.getSchema(), null, new String[] { "TABLE" });
422                if (!tableResults.next()) {
423                    log("Removing " + tableName);
424                    toRemove.add(tableName);
425                }
426                tableResults.close();
427            }
428        }
429        catch (Exception e) {
430            throw new BuildException(e);
431        }
432        finally {
433            if (targetConn != null) {
434                try {
435                    targetConn.close();
436                    targetConn = null;
437                }
438                catch (Exception e) {
439                }
440            }
441        }
442
443        for (String tableName : toRemove) {
444            retval.remove(tableName);
445        }
446        
447        return retval;
448    }
449
450    private Map<String, Integer> getColumnMap(final RdbmsConfig source, final RdbmsConfig target, String tableName) {
451        final Connection targetDb = target.getConnection();
452        final Connection sourceDb = source.getConnection();
453        final Map<String,Integer> retval = new HashMap<String,Integer>();
454        final Collection<String> toRemove = new ArrayList<String>();
455        try {
456            final Statement state = targetDb.createStatement();                
457            final ResultSet altResults = state.executeQuery("select * from " + tableName + " where 1 = 0");
458            final ResultSetMetaData metadata = altResults.getMetaData();
459            
460            for (int i = 1; i <= metadata.getColumnCount(); i++) {
461                retval.put(metadata.getColumnName(i),
462                           metadata.getColumnType(i));
463            }
464            altResults.close();
465            state.close();
466        }
467        catch (Exception e) {
468            throw new BuildException(e);
469        }
470
471        for (final String column : retval.keySet()) {
472            try {
473                final Statement state = targetDb.createStatement();                
474                final ResultSet altResults = state.executeQuery("select * from " + tableName + " where 1 = 0");
475                final ResultSetMetaData metadata = altResults.getMetaData();
476
477                for (int i = 1; i <= metadata.getColumnCount(); i++) {
478                    retval.put(metadata.getColumnName(i),
479                               metadata.getColumnType(i));
480                }
481                altResults.close();
482                state.close();
483            }
484            catch (Exception e) {
485                throw new BuildException(e);
486            }
487        }
488
489        for (final String column : toRemove) {
490            retval.remove(column);
491        }
492        
493        return retval;
494    }
495
496    private int getTableRecordCount(Connection conn, String tableName) {
497        final String query = String.format(RECORD_COUNT_QUERY, tableName);
498        Statement statement = null;
499        try {
500            statement = conn.createStatement();
501            final ResultSet results = statement.executeQuery(query);
502            results.next();
503            final int retval = results.getInt(COUNT_FIELD);
504            results.close();
505            return retval;
506        }
507        catch (Exception e) {
508            if (e.getMessage().contains("ORA-00942")) {
509                log("Couldn't find " + tableName);
510                log("Tried insert statement " + query);
511            }
512            log("Exception executing " + query);
513            throw new BuildException(e);
514        }
515        finally {
516            try {
517                if (statement != null) {
518                    statement.close();
519                    statement = null;
520                }
521            }
522            catch (Exception e) {
523            }
524        }
525    }
526
527    private void debug(String msg) {
528        log(msg, MSG_DEBUG);
529    }
530
531    private Connection openSource() {
532        return openConnection(getSource());
533    }
534
535    private Connection openTarget() {
536        return openConnection(getTarget());
537    }
538
539    private Connection openConnection(String reference) {
540        final RdbmsConfig config = (RdbmsConfig) getProject().getReference(reference);
541        return openConnection(config);
542    }
543    
544    private Connection openConnection(RdbmsConfig config) {
545        Connection retval = null;
546        
547        while (retval == null) {
548            try {
549                debug("Loading schema " + config.getSchema() + " at url " + config.getUrl());
550                Class.forName(config.getDriver());
551
552                retval = DriverManager.getConnection(config.getUrl(), config.getUsername(), config.getPassword());
553                retval.setAutoCommit(false);
554
555
556                // If this is an HSQLDB database, then we probably want to turn off logging for permformance
557                if (config.getDriver().indexOf("hsqldb") > -1) {
558                    debug("Disabling hsqldb log");
559                    final Statement st = retval.createStatement();
560                    st.execute("SET FILES LOG FALSE");
561                    st.close();
562                }
563                
564            }
565            catch (Exception e) {
566                // throw new BuildException(e);
567            }
568        }
569        
570        return retval;
571    }
572
573    /**
574     * Helper class for incrementing values
575     */
576    private class Incrementor {
577        private int value;
578        
579        public Incrementor() {
580            value = 0;
581        }
582        
583        public int getValue() {
584            return value;
585        }
586
587        public void increment() {
588            value++;
589        }
590
591        public void increment(int by) {
592            value += by;
593        }
594    }
595
596    private class ProgressObservable extends Observable {
597        public void incrementRecord() {
598            setChanged();
599            notifyObservers();
600            clearChanged();
601        }
602    }
603
604    /**
605     * Observer for handling progress
606     * 
607     */
608    private class ProgressObserver implements Observer {
609
610        private float total;
611        private float progress;
612        private float length;
613        private float ratio;
614        private String template;
615        private float count;
616        private PrintStream out;
617        
618        public ProgressObserver(final float total,
619                                final float length,
620                                final float ratio,
621                                final String template) {
622            this.total    = total;
623            this.template = template;
624            this.ratio    = ratio;
625            this.length   = length;
626            this.count    = 0;
627            
628            try {
629                final Field field = Main.class.getDeclaredField("out");
630                field.setAccessible(true);
631                out = (PrintStream) field.get(null);
632            }
633            catch (Exception e) {
634                e.printStackTrace();
635            }
636        }
637
638        public synchronized void update(Observable o, Object arg) {
639            count++;
640
641            final int percent = (int) ((count / total) * 100f);
642            final int progress = (int) ((count / total) * (100f * ratio));
643            final StringBuilder progressBuffer = new StringBuilder();
644                
645            for (int x = 0; x < progress; x++) {
646                progressBuffer.append('=');
647            }
648            
649            for (int x = progress; x < length; x++) {
650                progressBuffer.append(' ');
651            }
652            int roll = (int) (count / (total / 1000));
653
654            if (getProject().getProperty("run_from_ant") == null) {
655                out.print(String.format(template, progressBuffer, carr[roll % carr.length], percent, (int) count, (int) total));
656            }
657            else if ((count % 5000) == 0 || count == total) {
658                out.println(String.format("(%s)%% %s of %s records", (int) ((count / total) * 100), (int) count, (int) total));
659            }
660        }
661    }
662}