#include #include #include #include #include #include #include #include #include class LinkedList { int first, nnz; const int EOL = INT_MAX-1; std::vector next; std::vector value; public: LinkedList(int num_elements) { first = EOL; nnz = 0; next.resize(num_elements,INT_MAX); value.resize(num_elements,0.0); } void Put(int i1, int i2, const std::vector & ja, const std::vector & a, double coef) { assert(Empty()); for(int i = i1; i < i2; ++i) { value[ja[i]] = a[i]*coef; next[ja[i]] = first; first = ja[i]; nnz++; } } void Add(int i1, int i2, const std::vector & ja, const std::vector & a, double coef) { for(int i = i1; i < i2; ++i) { value[ja[i]] += a[i]*coef; if( next[ja[i]] == INT_MAX ) { next[ja[i]] = first; first = ja[i]; nnz++; } } } void Sort() { std::vector cols(nnz); int i = first, k = 0; while(i != EOL) { cols[k++] = i; i = next[i]; } if( !cols.empty() ) { std::sort(cols.begin(),cols.end()); first = cols[0]; for(int q = 1; q < (int)cols.size(); ++q) next[cols[q-1]] = cols[q]; next[cols.back()] = EOL; } } void Mult(double coef) { int i = first; while( i != EOL ) { value[i] *= coef; i = next[i]; } } void Get(std::vector & ja, std::vector & a) const { int i = first; while( i != EOL ) { ja.push_back(i); a.push_back(value[i]); i = next[i]; } } void Clear() { int i = first, j; while(i != EOL) { j = next[i]; value[i] = 0.0; next[i] = INT_MAX; i = j; } first = EOL; nnz = 0; } void Print() const { int i = first; while( i != EOL ) { std::cout << "[" << i << "]: " << value[i] << std::endl; i = next[i]; } } int Size() const {return nnz;} bool Empty() const {return first == EOL;} }; class CSRMatrix { std::vector ia, ja; std::vector a; public: CSRMatrix(const std::vector & ia, const std::vector & ja, const std::vector & a) : ia(ia), ja(ja), a(a) {} // C = A * B CSRMatrix operator *(const CSRMatrix & B) const { const CSRMatrix & A = *this; int ColsA = 0, ColsB = 0; for(int i = 0; i < (int)A.ja.size(); ++i) ColsA = std::max(ColsA,A.ja[i]+1); for(int i = 0; i < (int)B.ja.size(); ++i) ColsB = std::max(ColsB,B.ja[i]+1); assert(ColsA <= B.Size()); std::vector ia(A.Size()+1), ja; std::vector a; LinkedList List(ColsB); ia[0] = 0; for(int i = 0; i < A.Size(); ++i) { for(int j = A.ia[i]; j < A.ia[i+1]; ++j) List.Add(B.ia[A.ja[j]],B.ia[A.ja[j]+1],B.ja,B.a,A.a[j]); List.Sort(); //needed for printing List.Get(ja,a); ia[i+1] = (int)ja.size(); List.Clear(); } return CSRMatrix(ia,ja,a); } void Print(int w = 5) const { int Cols = 0; for(int i = 0; i < (int)ja.size(); ++i) Cols = std::max(Cols,ja[i]+1); for(int i = 0; i < Size(); ++i) { std::cout << "|"; int k = 0; for(int j = ia[i]; j < ia[i+1]; ++j) { while(k++ < ja[j]) std::cout << std::setw(w) << " "; std::cout << std::setw(w) << a[j]; } while(k++ < Cols) std::cout << std::setw(w) << " "; std::cout << "|" << std::endl; } } void PrintContents() const { std::cout << "ia: "; for(int i = 0; i < (int)ia.size(); ++i) std::cout << ia[i] << " "; std::cout << std::endl; std::cout << "ja: "; for(int i = 0; i < (int)ja.size(); ++i) std::cout << ja[i] << " "; std::cout << std::endl; std::cout << "a: "; for(int i = 0; i < (int)a.size(); ++i) std::cout << a[i] << " "; std::cout << std::endl; } int Size() const {return (int)ia.size()-1;} }; int main(int argc, char ** argv) { //example matrices /* * matrix A (2x4) * | 0.4 1.0 | * | 0.5 0.8 0.2 | */ std::vector iaA = {0, 2, 5}; std::vector jaA = { 0, 1, 1, 2, 3}; std::vector aA = { 0.4, 1.0, 0.5, 0.8, 0.2}; CSRMatrix A(iaA,jaA,aA); std::cout << "A:" << std::endl; A.Print(); /* * matrix B (4x5) * | 0.5 1.0 | * | 0.4 0.6 0.8 | * | 0.5 1.0 2.0 0.1 | * | 0.5 3.1 | */ std::vector iaB = {0, 2, 5, 9, 11}; std::vector jaB = { 0, 3, 0, 1, 4, 1, 2, 3, 4, 3, 4}; std::vector aB = { 0.5, 1.0, 0.4, 0.6, 0.8, 0.5, 1.0, 2.0, 0.1, 0.5, 3.1}; CSRMatrix B(iaB,jaB,aB); std::cout << "B:" << std::endl; B.Print(); /* * expect matrix C (2x5) * | 0.6 0.6 0.4 0.8 | * | 0.2 0.7 0.8 1.7 1.1 | */ CSRMatrix C = A*B; std::cout << "C:" << std::endl; C.Print(); C.PrintContents(); return 0; }