#define __TBB_PREVIEW_MUTEXES 1 //#define __TBB_cache_aligned_allocator_H 1 //#define __TBB_flow_graph_H 1 #define __TBB_mutex_H 1 #include "tensor/cpu_row_sparse_tensor.h" #include "tensor/t_data.h" #include "tensor/cpu_dense_tensor.h" #include "tensor/cpu_sparse_tensor.h" #include "tensor/mkl_helper.h" #include "util/mem_holder.h" #include #include //#include #include #include #include #include using namespace oneapi; namespace gnn { template TensorTemplate::TensorTemplate() : data(nullptr), is_full(false) { row_idxes.Reshape({0}); } template void TensorTemplate::Reshape(std::vector l) { this->shape.Reshape(l); if (this->data == nullptr) this->data = std::make_shared< DenseData >(); this->data->Resize(this->shape.Count()); is_full = true; row_idxes.Reshape({this->shape.dims[0]}); row_idxes.Reshape({0}); } template void TensorTemplate::ReshapeLike(RowSpTensor& src) { ASSERT(!is_full && row_idxes.shape.Count() == 0, "should be empty before reshaping"); ASSERT(data && data->mem_size >= src.shape.Count(), "should manually allocate memory before reshaping"); this->shape.Reshape(src.shape.dims); this->is_full = src.is_full; if (!src.is_full) row_idxes.CopyFrom(src.row_idxes); } template void TensorTemplate::RowSparseCopy(DTensor& src) { if (is_full) Full().CopyFrom(src); else if (row_idxes.shape.Count()) { size_t row_cnt = row_idxes.shape.Count(); size_t dim = this->shape.Count(1); tbb::parallel_for(size_t(0), row_cnt, size_t(1), [&](size_t i){ size_t row_idx = row_idxes.data->ptr[i]; memcpy(data->ptr + row_idx * dim, src.data->ptr + row_idx * dim, sizeof(Dtype) * dim); }); } } template void TensorTemplate::Scale(Dtype scalar) { if (is_full) Full().Scale(scalar); else if (row_idxes.shape.Count()) { size_t row_cnt = row_idxes.shape.Count(); size_t dim = this->shape.Count(1); tbb::parallel_for(size_t(0), row_cnt, size_t(1), [&](size_t i){ size_t row_idx = row_idxes.data->ptr[i]; Dtype* cur_ptr = data->ptr + row_idx * dim; for (size_t i = 0; i < dim; ++i) cur_ptr[i] *= scalar; }); } } template void TensorTemplate::Sqrt() { if (is_full) Full().Sqrt(); else if (row_idxes.shape.Count()) { size_t row_cnt = row_idxes.shape.Count(); size_t dim = this->shape.Count(1); tbb::parallel_for(size_t(0), row_cnt, size_t(1), [&](size_t i){ size_t row_idx = row_idxes.data->ptr[i]; MKL_Sqrt(dim, data->ptr + row_idx * dim, data->ptr + row_idx * dim); }); } } template void TensorTemplate::RowSparseAdd(Dtype scalar) { if (is_full) Full().Add(scalar); else if (row_idxes.shape.Count()) { size_t row_cnt = row_idxes.shape.Count(); size_t dim = this->shape.Count(1); tbb::parallel_for(size_t(0), row_cnt, size_t(1), [&](size_t i){ size_t row_idx = row_idxes.data->ptr[i]; Dtype* cur_ptr = data->ptr + row_idx * dim; for (size_t i = 0; i < dim; ++i) cur_ptr[i] += scalar; }); } } template void TensorTemplate::RowSparseInv() { if (is_full) Full().Inv(); else if (row_idxes.shape.Count()) { size_t row_cnt = row_idxes.shape.Count(); size_t dim = this->shape.Count(1); tbb::parallel_for(size_t(0), row_cnt, size_t(1), [&](size_t i){ size_t row_idx = row_idxes.data->ptr[i]; MKL_Inv(dim, data->ptr + row_idx * dim, data->ptr + row_idx * dim); }); } } template void TensorTemplate::ElewiseMul(DTensor& src) { ASSERT(this->shape == src.shape, "shape doesn't match in ElewiseMul"); if (is_full) Full().ElewiseMul(src); else if (row_idxes.shape.Count()) { size_t row_cnt = row_idxes.shape.Count(); size_t dim = this->shape.Count(1); tbb::parallel_for(size_t(0), row_cnt, size_t(1), [&](size_t i){ size_t row_idx = row_idxes.data->ptr[i]; MKL_Mul(dim, src.data->ptr + dim * row_idx, this->data->ptr + dim * row_idx, this->data->ptr + dim * row_idx); }); } } template MatType TensorTemplate::GetMatType() { return MatType::row_sparse; } template MatMode TensorTemplate::GetMatMode() { return MatMode::cpu; } template DTensor TensorTemplate::Full() { is_full = true; return DTensor(this->shape, this->data->ptr); } template void TensorTemplate::RowSpZeros() { if (is_full) Full().Zeros(); else if (row_idxes.shape.Count()) { size_t row_cnt = row_idxes.shape.Count(); size_t dim = this->shape.Count(1); tbb::parallel_for(size_t(0), row_cnt, size_t(1), [&](size_t i){ size_t row_idx = row_idxes.data->ptr[i]; memset(data->ptr + row_idx * dim, 0, sizeof(Dtype) * dim); }); } row_idxes.Reshape({0}); is_full = false; } template void TensorTemplate::FullZeros() { is_full = true; RowSpZeros(); } template void TensorTemplate::RowSparseFill(Dtype scalar) { throw std::logic_error(std::string("not implemented")); } template void TensorTemplate::InsertRowIdxes(size_t cnt, int* new_idxes) { // merge size_t cur_cnt = row_idxes.shape.Count(); idx_buf.Reshape({cnt + cur_cnt}); memcpy(idx_buf.data->ptr, row_idxes.data->ptr, sizeof(int) * cur_cnt); memcpy(idx_buf.data->ptr + cur_cnt, new_idxes, sizeof(int) * cnt); // unique std::sort(idx_buf.data->ptr, idx_buf.data->ptr + idx_buf.shape.Count()); auto last = std::unique(idx_buf.data->ptr, idx_buf.data->ptr + idx_buf.shape.Count()); cur_cnt = last - idx_buf.data->ptr; // copy row_idxes.Reshape({cur_cnt}); memcpy(row_idxes.data->ptr, idx_buf.data->ptr, sizeof(int) * cur_cnt); } template void TensorTemplate::SparseMM(SpTensor& a, DTensor& b, Trans transA, Trans transB, Dtype alpha, Dtype beta) { ASSERT(transA == Trans::T, "only for bp right now"); if (is_full) { Full().MM(a, b, transA, transB, alpha, beta); return; } InsertRowIdxes(a.data->nnz, a.data->col_idx); Full().MM(a, b, transA, transB, alpha, beta); is_full = false; } template void TensorTemplate::RowSparseAxpy(Dtype a, DTensor& x) { if (is_full) Full().Axpy(a, x); else if (row_idxes.shape.Count()) { size_t row_cnt = row_idxes.shape.Count(); size_t dim = this->shape.Count(1); tbb::parallel_for(size_t(0), row_cnt, size_t(1), [&](size_t i){ size_t row_idx = row_idxes.data->ptr[i]; MKL_Axpy(dim, a, x.data->ptr + row_idx * dim, data->ptr + row_idx * dim); }); } } template void TensorTemplate::RowSparseAxpby(Dtype a, DTensor& x, Dtype b) { if (is_full) Full().Axpby(a, x, b); else if (row_idxes.shape.Count()) { size_t row_cnt = row_idxes.shape.Count(); size_t dim = this->shape.Count(1); tbb::parallel_for(size_t(0), row_cnt, size_t(1), [&](size_t i){ size_t row_idx = row_idxes.data->ptr[i]; MKL_Axpby(dim, a, x.data->ptr + row_idx * dim, b, data->ptr + row_idx * dim); }); } } template Dtype TensorTemplate::Norm2() { if (is_full) return Full().Norm2(); else if (row_idxes.shape.Count()) { size_t row_cnt = row_idxes.shape.Count(); size_t dim = this->shape.Count(1); Dtype total_norm = 0.0; //using namespace tbb::detail::d1; std::mutex ll; tbb::parallel_for(size_t(0), row_cnt, size_t(1), [&](size_t i){ size_t row_idx = row_idxes.data->ptr[i]; auto norm = MKL_Norm2(dim, data->ptr + row_idx * dim); norm = norm * norm; ll.lock(); total_norm += norm; ll.unlock(); }); return sqrt(total_norm); } else return 0; } template void TensorTemplate::Square() { if (is_full) Full().Square(); else if (row_idxes.shape.Count()) { size_t row_cnt = row_idxes.shape.Count(); size_t dim = this->shape.Count(1); tbb::parallel_for(size_t(0), row_cnt, size_t(1), [&](size_t i){ size_t row_idx = row_idxes.data->ptr[i]; MKL_Square(dim, data->ptr + row_idx * dim, data->ptr + row_idx * dim); }); } } template class TensorTemplate; template class TensorTemplate; TensorTemplate::TensorTemplate() { } void TensorTemplate::Reshape(std::vector l) { } MatType TensorTemplate::GetMatType() { return MatType::row_sparse; } MatMode TensorTemplate::GetMatMode() { return MatMode::cpu; } template class TensorTemplate; }