001/*
002 * $Id: DistHashTableMPI.java 5416 2016-01-27 23:07:44Z kredel $
003 */
004
005package edu.jas.util;
006
007
008import java.io.IOException;
009import java.util.AbstractMap;
010import java.util.ArrayList;
011import java.util.Collection;
012import java.util.Iterator;
013import java.util.List;
014import java.util.Set;
015import java.util.SortedMap;
016import java.util.TreeMap;
017
018import mpi.Comm;
019import mpi.MPI;
020import mpi.MPIException;
021import mpi.Status;
022
023import org.apache.log4j.Logger;
024
025import edu.jas.kern.MPIEngine;
026
027
028/**
029 * Distributed version of a HashTable using MPI. Implemented with a SortedMap /
030 * TreeMap to keep the sequence order of elements. Implemented using MPI
031 * transport or TCP transport.
032 * @author Heinz Kredel
033 */
034
035public class DistHashTableMPI<K, V> extends AbstractMap<K, V> {
036
037
038    private static final Logger logger = Logger.getLogger(DistHashTableMPI.class);
039
040
041    private static final boolean debug = logger.isDebugEnabled();
042
043
044    /*
045     * Backing data structure.
046     */
047    protected final SortedMap<K, V> theList;
048
049
050    /*
051     * Thread for receiving pairs.
052     */
053    protected DHTMPIListener<K, V> listener;
054
055
056    /*
057     * MPI communicator.
058     */
059    protected final Comm engine;
060
061
062    /*
063     * Size of Comm.
064     */
065    private final int size;
066
067
068    /*
069     * This rank.
070     */
071    private final int rank;
072
073
074    /**
075     * Message tag for DHT communicaton.
076     */
077    public static final int DHTTAG = MPIEngine.TAG + 1;
078
079
080    /*
081     * TCP/IP object channels.
082     */
083    private final SocketChannel[] soc;
084
085
086    /**
087     * Transport layer.
088     * true: use TCP/IP socket layer, false: use MPI transport layer.
089     */
090    static final boolean useTCP = false;
091
092
093    /**
094     * DistHashTableMPI.
095     */
096    public DistHashTableMPI() throws MPIException, IOException {
097        this(MPIEngine.getCommunicator());
098    }
099
100
101    /**
102     * DistHashTableMPI.
103     * @param args command line for MPI runtime system.
104     */
105    public DistHashTableMPI(String[] args) throws MPIException, IOException {
106        this(MPIEngine.getCommunicator(args));
107    }
108
109
110    /**
111     * DistHashTableMPI.
112     * @param cm MPI communicator to use.
113     */
114    public DistHashTableMPI(Comm cm) throws MPIException, IOException {
115        engine = cm;
116        rank = engine.Rank();
117        size = engine.Size();
118        if (useTCP) { // && soc == null
119            int port = ChannelFactory.DEFAULT_PORT + 11;
120            ChannelFactory cf;
121            if (rank == 0) {
122                cf = new ChannelFactory(port); 
123                cf.init();
124                soc = new SocketChannel[size];
125                soc[0] = null;
126                try {
127                    for (int i = 1; i < size; i++) {
128                        SocketChannel sc = cf.getChannel(); // TODO not correct wrt rank
129                        soc[i] = sc;
130                    }
131                } catch (InterruptedException e) {
132                    throw new IOException(e);
133                }
134                cf.terminate();
135            } else {
136                cf = new ChannelFactory(port-1); // in case of localhost
137                soc = new SocketChannel[1];
138                SocketChannel sc = cf.getChannel(MPIEngine.hostNames.get(0), port);
139                soc[0] = sc;
140                cf.terminate();
141            }
142        } else {
143            soc = null;
144        }
145        theList = new TreeMap<K, V>();
146        //theList = new ConcurrentSkipListMap<K, V>(); // Java 1.6
147        listener = new DHTMPIListener<K, V>(engine, soc, theList);
148        logger.info("constructor: " + rank + "/" + size + ", useTCP: " + useTCP);
149    }
150
151
152    /**
153     * Hash code.
154     */
155    @Override
156    public int hashCode() {
157        return theList.hashCode();
158    }
159
160
161    /**
162     * Equals.
163     */
164    @Override
165    public boolean equals(Object o) {
166        return theList.equals(o);
167    }
168
169
170    /**
171     * Contains key.
172     */
173    @Override
174    public boolean containsKey(Object o) {
175        return theList.containsKey(o);
176    }
177
178
179    /**
180     * Contains value.
181     */
182    @Override
183    public boolean containsValue(Object o) {
184        return theList.containsValue(o);
185    }
186
187
188    /**
189     * Get the values as Collection.
190     */
191    @Override
192    public Collection<V> values() {
193        synchronized (theList) {
194            return new ArrayList<V>(theList.values());
195        }
196    }
197
198
199    /**
200     * Get the keys as set.
201     */
202    @Override
203    public Set<K> keySet() {
204        synchronized (theList) {
205            return theList.keySet();
206        }
207    }
208
209
210    /**
211     * Get the entries as Set.
212     */
213    @Override
214    public Set<Entry<K, V>> entrySet() {
215        synchronized (theList) {
216            return theList.entrySet();
217        }
218    }
219
220
221    /**
222     * Get the internal list, convert from Collection.
223     */
224    public List<V> getValueList() {
225        synchronized (theList) {
226            return new ArrayList<V>(theList.values());
227        }
228    }
229
230
231    /**
232     * Get the internal sorted map. For synchronization purpose in normalform.
233     */
234    public SortedMap<K, V> getList() {
235        return theList;
236    }
237
238
239    /**
240     * Size of the (local) list.
241     */
242    @Override
243    public int size() {
244        synchronized (theList) {
245            return theList.size();
246        }
247    }
248
249
250    /**
251     * Is the List empty?
252     */
253    @Override
254    public boolean isEmpty() {
255        synchronized (theList) {
256            return theList.isEmpty();
257        }
258    }
259
260
261    /**
262     * List key iterator.
263     */
264    public Iterator<K> iterator() {
265        synchronized (theList) {
266            return theList.keySet().iterator();
267        }
268    }
269
270
271    /**
272     * List value iterator.
273     */
274    public Iterator<V> valueIterator() {
275        synchronized (theList) {
276            return theList.values().iterator();
277        }
278    }
279
280
281    /**
282     * Put object to the distributed hash table. Blocks until the key value pair
283     * is send and received from the server.
284     * @param key
285     * @param value
286     */
287    public void putWait(K key, V value) {
288        put(key, value); // = send
289        // assume key does not change multiple times before test:
290        V val = null;
291        do {
292            val = getWait(key);
293            //System.out.print("#");
294        } while (!value.equals(val));
295    }
296
297
298    /**
299     * Put object to the distributed hash table. Returns immediately after
300     * sending does not block.
301     * @param key
302     * @param value
303     */
304    @Override
305    public V put(K key, V value) {
306        if (key == null || value == null) {
307            throw new NullPointerException("null keys or values not allowed");
308        }
309        try {
310            DHTTransport<K, V> tc = DHTTransport.<K, V> create(key, value);
311            for (int i = 1; i < size; i++) { // send not to self.listener
312                if (useTCP) {
313                    soc[i].send(tc);
314                } else {
315                    DHTTransport[] tcl = new DHTTransport[] { tc };
316                    synchronized (MPIEngine.class) { // do not remove
317                        engine.Send(tcl, 0, tcl.length, MPI.OBJECT, i, DHTTAG);
318                    }
319                }
320            }
321            synchronized (theList) { // add to self.listener
322                theList.put(key, value); //avoid seri: tc.key(), tc.value());
323                theList.notifyAll();
324            }
325            if (debug) {
326                K k = tc.key();
327                if (!key.equals(k)) {
328                    logger.warn("deserial(serial)) != key: " + key + " != " + k);
329                }
330                V v = tc.value();
331                if (!value.equals(v)) {
332                    logger.warn("deserial(serial)) != value: " + value + " != " + v);
333                }
334
335            }
336            //System.out.println("send: "+tc);
337        } catch (ClassNotFoundException e) {
338            logger.info("sending(key=" + key + ")");
339            logger.warn("send " + e);
340            e.printStackTrace();
341        } catch (MPIException e) {
342            logger.info("sending(key=" + key + ")");
343            logger.warn("send " + e);
344            e.printStackTrace();
345        } catch (IOException e) {
346            logger.info("sending(key=" + key + ")");
347            logger.warn("send " + e);
348            e.printStackTrace();
349        }
350        return null;
351    }
352
353
354    /**
355     * Get value under key from DHT. Blocks until the object is send and
356     * received from the server (actually it blocks until some value under key
357     * is received).
358     * @param key
359     * @return the value stored under the key.
360     */
361    public V getWait(K key) {
362        V value = null;
363        try {
364            synchronized (theList) {
365                value = theList.get(key);
366                while (value == null) {
367                    //System.out.print("-");
368                    theList.wait(100);
369                    value = theList.get(key);
370                }
371            }
372        } catch (InterruptedException e) {
373            //Thread.currentThread().interrupt();
374            e.printStackTrace();
375            return value;
376        }
377        return value;
378    }
379
380
381    /**
382     * Get value under key from DHT. If no value is jet available null is
383     * returned.
384     * @param key
385     * @return the value stored under the key.
386     */
387    @Override
388    public V get(Object key) {
389        synchronized (theList) {
390            return theList.get(key);
391        }
392    }
393
394
395    /**
396     * Clear the List. Caveat: must be called on all clients.
397     */
398    @Override
399    public void clear() {
400        // send clear message to others
401        synchronized (theList) {
402            theList.clear();
403        }
404    }
405
406
407    /**
408     * Initialize and start the list thread.
409     */
410    public void init() {
411        logger.info("init " + listener + ", theList = " + theList);
412        if (listener == null) {
413            return;
414        }
415        if (listener.isDone()) {
416            return;
417        }
418        if (debug) {
419            logger.debug("initialize " + listener);
420        }
421        synchronized (theList) {
422            listener.start();
423        }
424    }
425
426
427    /**
428     * Terminate the list thread.
429     */
430    public void terminate() {
431        if (listener == null) {
432            return;
433        }
434        if (debug) {
435            Runtime rt = Runtime.getRuntime();
436            logger.debug("terminate " + listener + ", runtime = " + rt.hashCode());
437        }
438        listener.setDone();
439        DHTTransport<K, V> tc = new DHTTransportTerminate<K, V>();
440        try {
441            if (rank == 0) {
442                //logger.info("send(" + rank + ") terminate");
443                for (int i = 1; i < size; i++) { // send not to self.listener
444                    if (useTCP) {
445                        soc[i].send(tc);
446                    } else {
447                        DHTTransport[] tcl = new DHTTransport[] { tc };
448                        synchronized (MPIEngine.class) { // do not remove
449                            engine.Send(tcl, 0, tcl.length, MPI.OBJECT, i, DHTTAG);
450                        }
451                    }
452                }
453            }
454        } catch (MPIException e) {
455            logger.info("sending(terminate)");
456            logger.info("send " + e);
457            e.printStackTrace();
458        } catch (IOException e) {
459            logger.info("sending(terminate)");
460            logger.info("send " + e);
461            e.printStackTrace();
462        }
463        try {
464            while (listener.isAlive()) {
465                //System.out.print("+++++");
466                listener.join(999);
467                listener.interrupt();
468            }
469        } catch (InterruptedException e) {
470            //Thread.currentThread().interrupt();
471        }
472        listener = null;
473    }
474
475}
476
477
478/**
479 * Thread to comunicate with the other DHT lists.
480 */
481class DHTMPIListener<K, V> extends Thread {
482
483
484    private static final Logger logger = Logger.getLogger(DHTMPIListener.class);
485
486
487    private static final boolean debug = logger.isDebugEnabled();
488
489
490    private final Comm engine;
491
492
493    private final SortedMap<K, V> theList;
494
495
496    private final SocketChannel[] soc;
497
498
499    private boolean goon;
500
501
502    /**
503     * Constructor.
504     */
505    DHTMPIListener(Comm cm, SocketChannel[] s, SortedMap<K, V> list) {
506        engine = cm;
507        theList = list;
508        goon = true;
509        soc = s;
510    }
511
512
513    /**
514     * Test if done.
515     */
516    boolean isDone() {
517        return !goon;
518    }
519
520
521    /**
522     * Set to done status.
523     */
524    void setDone() {
525        goon = false;
526    }
527
528
529    /**
530     * run.
531     */
532    @SuppressWarnings("unchecked")
533    @Override
534    public void run() {
535        logger.info("listener run() " + this);
536        int rank = -1;
537        DHTTransport<K, V> tc;
538        //goon = true;
539        while (goon) {
540            tc = null;
541            try {
542                if (rank < 0) {
543                    rank = engine.Rank();
544                }
545                if (rank == 0) {
546                    logger.info("listener on rank 0 stopped");
547                    goon = false;
548                    continue;
549                }
550                Object to = null;
551                if (DistHashTableMPI.useTCP) {
552                    to = soc[0].receive();
553                } else {
554                    DHTTransport[] tcl = new DHTTransport[1];
555                    Status stat = null;
556                    synchronized (MPIEngine.class) { // do not remove global static lock, // only from 0:
557                        stat = engine.Recv(tcl, 0, tcl.length, MPI.OBJECT, 0,             //MPI.ANY_SOURCE,
558                                        DistHashTableMPI.DHTTAG);
559                    }
560                    //logger.info("waitRequest done: stat = " + stat);
561                    if (stat == null) {
562                        goon = false;
563                        break;
564                    }
565                    int cnt = stat.Get_count(MPI.OBJECT);
566                    if (cnt == 0) {
567                        goon = false;
568                        break;
569                    } else if (cnt > 1) {
570                        logger.warn("ignoring " + (cnt - 1) + " received objects");
571                    }
572                    to = tcl[0];
573                }
574                tc = (DHTTransport<K, V>) to;
575                if (debug) {
576                    logger.debug("receive(" + tc + ")");
577                }
578                if (tc instanceof DHTTransportTerminate) {
579                    logger.info("receive(" + rank + ") terminate");
580                    goon = false;
581                    break;
582                }
583                if (this.isInterrupted()) {
584                    goon = false;
585                    break;
586                }
587                K key = tc.key();
588                if (key != null) {
589                    logger.info("receive(" + rank + "), key=" + key);
590                    V val = tc.value();
591                    synchronized (theList) {
592                        theList.put(key, val);
593                        theList.notifyAll();
594                    }
595                }
596            } catch (MPIException e) {
597                goon = false;
598                logger.warn("receive(MPI) " + e);
599                //e.printStackTrace();
600            } catch (ClassNotFoundException e) {
601                goon = false;
602                logger.info("receive(Class) " + e);
603                e.printStackTrace();
604            } catch (Exception e) {
605                goon = false;
606                logger.info("receive " + e);
607                e.printStackTrace();
608            }
609        }
610        logger.info("terminated at " + rank);
611    }
612
613}