001/*
002 * $Id: MPIChannel.java 5245 2015-05-01 14:03:06Z kredel $
003 */
004
005
006package edu.jas.util;
007
008
009import java.io.IOException;
010import java.util.Arrays;
011
012import mpi.Comm;
013import mpi.MPI;
014import mpi.MPIException;
015import mpi.Status;
016
017import org.apache.log4j.Logger;
018
019import edu.jas.kern.MPIEngine;
020
021
022/**
023 * MPIChannel provides a communication channel for Java objects using MPI or
024 * TCP/IP to a given rank.
025 * @author Heinz Kredel
026 */
027public final class MPIChannel {
028
029
030    private static final Logger logger = Logger.getLogger(MPIChannel.class);
031
032
033    public static final int CHANTAG = MPIEngine.TAG + 2;
034
035
036    /*
037     * Underlying MPI engine.
038     */
039    private final Comm engine; // essentially static when useTCP !
040
041
042    /*
043     * Size of Comm.
044     */
045    private final int size;
046
047
048    /*
049     * This rank.
050     */
051    private final int rank;
052
053
054    /*
055     * TCP/IP object channels with tags.
056     */
057    private static TaggedSocketChannel[] soc = null;
058
059
060    /*
061     * Transport layer.
062     * true: use TCP/IP socket layer, false: use MPI transport layer.
063     * Can not be set to false for OpenMPI Java: not working.
064     */
065    static final boolean useTCP = true;
066
067
068    /*
069     * Partner rank.
070     */
071    private final int partnerRank;
072
073
074    /*
075     * Message tag.
076     */
077    private final int tag;
078
079
080    /**
081     * Constructs a MPI channel on the given MPI engine.
082     * @param s MPI communicator object.
083     * @param r rank of MPI partner.
084     */
085    public MPIChannel(Comm s, int r) throws IOException, MPIException {
086        this(s, r, CHANTAG);
087    }
088
089
090    /**
091     * Constructs a MPI channel on the given MPI engine.
092     * @param s MPI communicator object.
093     * @param r rank of MPI partner.
094     * @param t tag for messages.
095     */
096    public MPIChannel(Comm s, int r, int t) throws IOException, MPIException {
097        engine = s;
098        rank = engine.Rank();
099        size = engine.Size();
100        if (r < 0 || size <= r) {
101            throw new IOException("r out of bounds: 0 <= r < size: " + r + ", " + size);
102        }
103        partnerRank = r;
104        tag = t;
105        synchronized (engine) {
106            if (soc == null && useTCP) {
107                int port = ChannelFactory.DEFAULT_PORT;
108                ChannelFactory cf;
109                if (rank == 0) {
110                    cf = new ChannelFactory(port);
111                    cf.init();
112                    soc = new TaggedSocketChannel[size];
113                    soc[0] = null;
114                    try {
115                        for (int i = 1; i < size; i++) {
116                            SocketChannel sc = cf.getChannel(); // TODO not correct wrt rank
117                            soc[i] = new TaggedSocketChannel(sc);
118                            soc[i].init();
119                        }
120                    } catch (InterruptedException e) {
121                        throw new IOException(e);
122                    }
123                    cf.terminate();
124                } else {
125                    cf = new ChannelFactory(port - 1); // in case of localhost
126                    soc = new TaggedSocketChannel[1];
127                    SocketChannel sc = cf.getChannel(MPIEngine.hostNames.get(0), port);
128                    soc[0] = new TaggedSocketChannel(sc);
129                    soc[0].init();
130                    cf.terminate();
131                }
132            }
133        }
134        logger.info("constructor: " + this.toString() + ", useTCP: " + useTCP);
135    }
136
137
138    /**
139     * Get the MPI engine.
140     */
141    public Comm getEngine() {
142        return engine;
143    }
144
145
146    /**
147     * Sends an object.
148     * @param v message object.
149     */
150    public void send(Object v) throws IOException, MPIException {
151        send(tag, v, partnerRank);
152    }
153
154
155    /**
156     * Sends an object.
157     * @param t message tag.
158     * @param v message object.
159     */
160    public void send(int t, Object v) throws IOException, MPIException {
161        send(t, v, partnerRank);
162    }
163
164
165    /**
166     * Sends an object.
167     * @param t message tag.
168     * @param v message object.
169     * @param pr partner rank.
170     */
171    void send(int t, Object v, int pr) throws IOException, MPIException {
172        if (useTCP) {
173            if (soc == null) {
174                logger.warn("soc not initialized: lost " + v);
175                return;
176            }
177            if (soc[pr] == null) {
178                logger.warn("soc[" + pr + "] not initialized: lost " + v);
179                return;
180            }
181            soc[pr].send(t, v);
182        } else {
183            Object[] va = new Object[] { v };
184            //synchronized (MPJEngine.class) {
185            engine.Send(va, 0, va.length, MPI.OBJECT, pr, t);
186            //}
187        }
188    }
189
190
191    /**
192     * Receives an object.
193     * @return a message object.
194     */
195    public Object receive() throws IOException, ClassNotFoundException, MPIException {
196        return receive(tag);
197    }
198
199
200    /**
201     * Receives an object.
202     * @param t message tag.
203     * @return a message object.
204     */
205    public Object receive(int t) throws IOException, ClassNotFoundException, MPIException {
206        if (useTCP) {
207            if (soc == null) {
208                logger.warn("soc not initialized");
209                return null;
210            }
211            if (soc[partnerRank] == null) {
212                logger.warn("soc[" + partnerRank + "] not initialized");
213                return null;
214            }
215            try {
216                return soc[partnerRank].receive(t);
217            } catch (InterruptedException e) {
218                throw new IOException(e);
219            }
220        }
221        Object[] va = new Object[1];
222        Status stat = null;
223        //synchronized (MPJEngine.class) {
224        stat = engine.Recv(va, 0, va.length, MPI.OBJECT, partnerRank, t);
225        //}
226        if (stat == null) {
227            throw new IOException("received null Status");
228        }
229        int cnt = stat.Get_count(MPI.OBJECT);
230        if (cnt == 0) {
231            throw new IOException("no object received");
232        }
233        if (cnt > 1) {
234            logger.warn("too many objects received, ignored " + (cnt - 1));
235        }
236        // int pr = stat.source;
237        // if (pr != partnerRank) {
238        //     logger.warn("received out of order message from " + pr);
239        // }
240        return va[0];
241    }
242
243
244    /**
245     * Closes the channel.
246     */
247    public void close() {
248        if (useTCP) {
249            if (soc == null) {
250                return;
251            }
252            for (int i = 0; i < soc.length; i++) {
253                if (soc[i] != null) {
254                    soc[i].close();
255                    soc[i] = null;
256                }
257            }
258        }
259    }
260
261
262    /**
263     * to string.
264     */
265    @Override
266    public String toString() {
267        return "MPIChannel(on=" + rank + ",to=" + partnerRank + ",tag=" + tag + "," + Arrays.toString(soc)
268                        + ")";
269    }
270
271}