ViewVC Help
View File | Revision Log | Show Annotations | View Changeset | Root Listing
root/OpenMD/trunk/src/parallel/Communicator.hpp
(Generate patch)

Comparing branches/development/src/parallel/Communicator.hpp (file contents):
Revision 1544 by gezelter, Fri Mar 18 19:31:52 2011 UTC vs.
Revision 1665 by gezelter, Tue Nov 22 20:38:56 2011 UTC

# Line 43 | Line 43
43   * [1]  Meineke, et al., J. Comp. Chem. 26, 252-271 (2005).            
44   * [2]  Fennell & Gezelter, J. Chem. Phys. 124, 234104 (2006).          
45   * [3]  Sun, Lin & Gezelter, J. Chem. Phys. 128, 24107 (2008).          
46 < * [4]  Vardeman & Gezelter, in progress (2009).                        
46 > * [4]  Kuang & Gezelter,  J. Chem. Phys. 133, 164101 (2010).
47 > * [5]  Vardeman, Stocker & Gezelter, J. Chem. Theory Comput. 7, 834 (2011).
48   */
49  
50   #ifndef PARALLEL_COMMUNICATOR_HPP
# Line 53 | Line 54
54   #include <mpi.h>
55   #include "math/SquareMatrix3.hpp"
56  
57 + using namespace std;
58   namespace OpenMD{
59    
60   #ifdef IS_MPI
61  
62 <  enum direction {
63 <    Row = 0,
64 <    Column = 1
62 >  enum communicatorType {
63 >    Global = 0,
64 >    Row = 1,
65 >    Column = 2
66    };
67      
68 <  template<typename T>
69 <  struct MPITraits
70 <  {
71 <    static const MPI::Datatype datatype;
72 <    static const int dim;
68 >  template<class T>
69 >  class MPITraits {
70 >  public:
71 >    static MPI::Datatype Type();
72 >    static int Length() { return 1; };
73    };
74    
75 <  template<> const MPI::Datatype MPITraits<int>::datatype = MPI_INT;
76 <  template<> const int MPITraits<int>::dim = 1;
77 <  template<> const MPI::Datatype MPITraits<RealType>::datatype = MPI_REALTYPE;
78 <  template<> const int MPITraits<RealType>::dim = 1;
79 <  template<> const MPI::Datatype MPITraits<Vector3d>::datatype = MPI_REALTYPE;
80 <  template<> const int MPITraits<Vector3d>::dim = 3;
81 <  template<> const MPI::Datatype MPITraits<Mat3x3d>::datatype = MPI_REALTYPE;
82 <  template<> const int MPITraits<Mat3x3d>::dim = 9;
75 >  template<> inline MPI::Datatype MPITraits<int>::Type() { return MPI_INT; }
76 >  template<> inline MPI::Datatype MPITraits<RealType>::Type() { return MPI_REALTYPE; }
77 >
78 >  template<class T, unsigned int Dim>
79 >  class MPITraits< Vector<T, Dim> > {
80 >  public:
81 >    static MPI::Datatype Type() { return MPITraits<T>::Type(); }
82 >    static int Length() {return Dim;}
83 >  };
84 >
85 >  template<class T>
86 >  class MPITraits< Vector3<T> > {
87 >  public:
88 >    static MPI::Datatype Type() { return MPITraits<T>::Type(); }
89 >    static int Length() {return 3;}
90 >  };
91 >
92 >  template<class T, unsigned int Row, unsigned int Col>
93 >  class MPITraits< RectMatrix<T, Row, Col> > {
94 >  public:
95 >    static MPI::Datatype Type() { return MPITraits<T>::Type(); }
96 >    static int Length() {return Row * Col;}
97 >  };
98 >
99 >  template<class T>
100 >  class MPITraits< SquareMatrix3<T> > {
101 >  public:
102 >    static MPI::Datatype Type() { return MPITraits<T>::Type(); }
103 >    static int Length() {return 9;}
104 >  };
105    
106 <  template<direction D, typename T>
106 >  
107 >  template<communicatorType D>
108    class Communicator {
109    public:
110      
111 <    Communicator<D, T>(int nObjects) {
111 >    Communicator<D>() {
112        
113        int nProc = MPI::COMM_WORLD.Get_size();
114        int myRank = MPI::COMM_WORLD.Get_rank();
115 <
115 >      
116        int nColumnsMax = (int) sqrt(RealType(nProc));
117  
118        int nColumns;
# Line 98 | Line 124 | namespace OpenMD{
124        rowIndex_ = myRank / nColumns;      
125        columnIndex_ = myRank % nColumns;
126  
127 <      if (D == Row) {
127 >      switch(D) {
128 >      case Row :
129          myComm = MPI::COMM_WORLD.Split(rowIndex_, 0);
130 <      } else {
130 >        break;
131 >      case Column:
132          myComm = MPI::COMM_WORLD.Split(columnIndex_, 0);
133 +        break;
134 +      case Global:
135 +        myComm = MPI::COMM_WORLD.Split(myRank, 0);
136        }
106        
107      int nCommProcs = myComm.Get_size();
137  
138 <      counts.reserve(nCommProcs);
139 <      displacements.reserve(nCommProcs);
138 >    }
139 >    
140 >    MPI::Intracomm getComm() { return myComm; }
141 >    
142 >  private:
143 >    int rowIndex_;
144 >    int columnIndex_;
145 >    MPI::Intracomm myComm;
146 >  };
147 >  
148  
149 <      planSize_ = MPITraits<T>::dim * nObjects;
150 <
149 >  template<typename T>
150 >  class Plan {
151 >  public:
152 >    
153 >    Plan<T>(MPI::Intracomm comm, int nObjects) {
154 >      myComm = comm;
155 >      int nCommProcs = myComm.Get_size();
156 >      
157 >      counts.resize(nCommProcs, 0);
158 >      displacements.resize(nCommProcs, 0);
159 >      
160 >      planSize_ = MPITraits<T>::Length() * nObjects;
161 >      
162        myComm.Allgather(&planSize_, 1, MPI::INT, &counts[0], 1, MPI::INT);
163 <
163 >      
164        displacements[0] = 0;
165        for (int i = 1; i < nCommProcs; i++) {
166          displacements[i] = displacements[i-1] + counts[i-1];
119        size_ += counts[i-1];
167        }
168  
169        size_ = 0;
# Line 125 | Line 172 | namespace OpenMD{
172        }
173      }
174  
175 <
176 <    void gather(std::vector<T>& v1, std::vector<T>& v2) {
175 >    
176 >    void gather(vector<T>& v1, vector<T>& v2) {
177        
178 +      // an assert would be helpful here to make sure the vectors are the
179 +      // correct geometry
180 +      
181        myComm.Allgatherv(&v1[0],
182                          planSize_,
183 <                        MPITraits<T>::datatype,
183 >                        MPITraits<T>::Type(),
184                          &v2[0],
185                          &counts[0],
186                          &displacements[0],
187 <                        MPITraits<T>::datatype);      
188 <    }
139 <
187 >                        MPITraits<T>::Type());      
188 >    }      
189      
190 <  
191 <    void scatter(std::vector<T>& v1, std::vector<T>& v2) {
192 <
190 >    void scatter(vector<T>& v1, vector<T>& v2) {
191 >      // an assert would be helpful here to make sure the vectors are the
192 >      // correct geometry
193 >            
194        myComm.Reduce_scatter(&v1[0], &v2[0], &counts[0],
195 <                            MPITraits<T>::datatype, MPI::SUM);
195 >                            MPITraits<T>::Type(), MPI::SUM);
196      }
197 <
197 >    
198      int getSize() {
199        return size_;
200      }
201      
202    private:
203      int planSize_;     ///< how many are on local proc
154    int rowIndex_;
155    int columnIndex_;
204      int size_;
205 <    std::vector<int> counts;
206 <    std::vector<int> displacements;
205 >    vector<int> counts;
206 >    vector<int> displacements;
207      MPI::Intracomm myComm;
208    };
209  

Diff Legend

Removed lines
+ Added lines
< Changed lines
> Changed lines