ViewVC Help
View File | Revision Log | Show Annotations | View Changeset | Root Listing
root/OpenMD/trunk/src/parallel/Communicator.hpp
(Generate patch)

Comparing branches/development/src/parallel/Communicator.hpp (file contents):
Revision 1539 by gezelter, Fri Jan 14 22:31:31 2011 UTC vs.
Revision 1601 by gezelter, Thu Aug 4 20:04:35 2011 UTC

# Line 46 | Line 46
46   * [4]  Vardeman & Gezelter, in progress (2009).                        
47   */
48  
49 < #ifndef FORCEDECOMPOSITION_COMMUNICATOR_HPP
50 < #define FORCEDECOMPOSITION_COMMUNICATOR_HPP
49 > #ifndef PARALLEL_COMMUNICATOR_HPP
50 > #define PARALLEL_COMMUNICATOR_HPP
51  
52   #include <config.h>
53   #include <mpi.h>
54   #include "math/SquareMatrix3.hpp"
55  
56 + using namespace std;
57   namespace OpenMD{
58    
59   #ifdef IS_MPI
60  
61 <  enum direction {
62 <    I = 0,
63 <    J = 1
61 >  enum communicatorType {
62 >    Global = 0,
63 >    Row = 1,
64 >    Column = 2
65    };
66      
67 <  template<typename T>
68 <  struct MPITraits
69 <  {
70 <    static const MPI::Datatype datatype;
71 <    static const int dim;
67 >  template<class T>
68 >  class MPITraits {
69 >  public:
70 >    static MPI::Datatype Type();
71 >    static int Length() { return 1; };
72    };
73    
74 <  template<> const MPI::Datatype MPITraits<int>::datatype = MPI_INT;
75 <  template<> const int MPITraits<int>::dim = 1;
76 <  template<> const MPI::Datatype MPITraits<RealType>::datatype = MPI_REALTYPE;
77 <  template<> const int MPITraits<RealType>::dim = 1;
78 <  template<> const MPI::Datatype MPITraits<Vector3d>::datatype = MPI_REALTYPE;
79 <  template<> const int MPITraits<Vector3d>::dim = 3;
80 <  template<> const MPI::Datatype MPITraits<Mat3x3d>::datatype = MPI_REALTYPE;
81 <  template<> const int MPITraits<Mat3x3d>::dim = 9;
74 >  template<> inline MPI::Datatype MPITraits<int>::Type() { return MPI_INT; }
75 >  template<> inline MPI::Datatype MPITraits<RealType>::Type() { return MPI_REALTYPE; }
76 >
77 >  template<class T, unsigned int Dim>
78 >  class MPITraits< Vector<T, Dim> > {
79 >  public:
80 >    static MPI::Datatype Type() { return MPITraits<T>::Type(); }
81 >    static int Length() {return Dim;}
82 >  };
83 >
84 >  template<class T>
85 >  class MPITraits< Vector3<T> > {
86 >  public:
87 >    static MPI::Datatype Type() { return MPITraits<T>::Type(); }
88 >    static int Length() {return 3;}
89 >  };
90 >
91 >  template<class T, unsigned int Row, unsigned int Col>
92 >  class MPITraits< RectMatrix<T, Row, Col> > {
93 >  public:
94 >    static MPI::Datatype Type() { return MPITraits<T>::Type(); }
95 >    static int Length() {return Row * Col;}
96 >  };
97 >
98 >  template<class T>
99 >  class MPITraits< SquareMatrix3<T> > {
100 >  public:
101 >    static MPI::Datatype Type() { return MPITraits<T>::Type(); }
102 >    static int Length() {return 9;}
103 >  };
104    
105 <  template<direction D, typename T>
106 <  class Comm {
105 >  
106 >  template<communicatorType D>
107 >  class Communicator {
108    public:
109      
110 <    Comm<D, T>(int nObjects) {
110 >    Communicator<D>() {
111        
112        int nProc = MPI::COMM_WORLD.Get_size();
113        int myRank = MPI::COMM_WORLD.Get_rank();
114 <
114 >      
115        int nColumnsMax = (int) sqrt(RealType(nProc));
116  
117        int nColumns;
# Line 98 | Line 123 | namespace OpenMD{
123        rowIndex_ = myRank / nColumns;      
124        columnIndex_ = myRank % nColumns;
125  
126 <      if (D == I) {
126 >      switch(D) {
127 >      case Row :
128          myComm = MPI::COMM_WORLD.Split(rowIndex_, 0);
129 <      } else {
129 >        break;
130 >      case Column:
131          myComm = MPI::COMM_WORLD.Split(columnIndex_, 0);
132 +        break;
133 +      case Global:
134 +        myComm = MPI::COMM_WORLD.Split(myRank, 0);
135        }
106        
107      int nCommProcs = myComm.Get_size();
136  
137 <      counts.reserve(nCommProcs);
138 <      displacements.reserve(nCommProcs);
137 >    }
138 >    
139 >    MPI::Intracomm getComm() { return myComm; }
140 >    
141 >  private:
142 >    int rowIndex_;
143 >    int columnIndex_;
144 >    MPI::Intracomm myComm;
145 >  };
146 >  
147  
148 <      planSize_ = MPITraits<T>::dim * nObjects;
149 <
148 >  template<typename T>
149 >  class Plan {
150 >  public:
151 >    
152 >    Plan<T>(MPI::Intracomm comm, int nObjects) {
153 >      myComm = comm;
154 >      int nCommProcs = myComm.Get_size();
155 >      
156 >      counts.resize(nCommProcs, 0);
157 >      displacements.resize(nCommProcs, 0);
158 >      
159 >      planSize_ = MPITraits<T>::Length() * nObjects;
160 >      
161        myComm.Allgather(&planSize_, 1, MPI::INT, &counts[0], 1, MPI::INT);
162 <
162 >      
163        displacements[0] = 0;
164        for (int i = 1; i < nCommProcs; i++) {
165          displacements[i] = displacements[i-1] + counts[i-1];
166 <      }      
166 >      }
167 >
168 >      size_ = 0;
169 >      for (int i = 0; i < nCommProcs; i++) {
170 >        size_ += counts[i];
171 >      }
172      }
173  
174 <
175 <    void gather(std::vector<T>& v1, std::vector<T>& v2) {
174 >    
175 >    void gather(vector<T>& v1, vector<T>& v2) {
176        
177 +      // an assert would be helpful here to make sure the vectors are the
178 +      // correct geometry
179 +      
180        myComm.Allgatherv(&v1[0],
181                          planSize_,
182 <                        MPITraits<T>::datatype,
182 >                        MPITraits<T>::Type(),
183                          &v2[0],
184                          &counts[0],
185                          &displacements[0],
186 <                        MPITraits<T>::datatype);      
187 <    }
133 <
186 >                        MPITraits<T>::Type());      
187 >    }      
188      
189 <  
190 <    void scatter(std::vector<T>& v1, std::vector<T>& v2) {
191 <
189 >    void scatter(vector<T>& v1, vector<T>& v2) {
190 >      // an assert would be helpful here to make sure the vectors are the
191 >      // correct geometry
192 >            
193        myComm.Reduce_scatter(&v1[0], &v2[0], &counts[0],
194 <                            MPITraits<T>::datatype, MPI::SUM);
194 >                            MPITraits<T>::Type(), MPI::SUM);
195      }
196      
197 +    int getSize() {
198 +      return size_;
199 +    }
200 +    
201    private:
202      int planSize_;     ///< how many are on local proc
203 <    int rowIndex_;
204 <    int columnIndex_;
205 <    std::vector<int> counts;
147 <    std::vector<int> displacements;
203 >    int size_;
204 >    vector<int> counts;
205 >    vector<int> displacements;
206      MPI::Intracomm myComm;
207    };
208  

Diff Legend

Removed lines
+ Added lines
< Changed lines
> Changed lines