#pragma once #include "daal.h" #include "logistic_loss_dense_default_batch_kernel.h" #include "logistic_loss_dense_default_batch_impl.i" //#include "logistic_loss_dense_default_batch_container.h" //#include "algorithm_container_base_common.h" using namespace daal::algorithms::optimization_solver; namespace logistic_prox_function { struct Parameter : public daal::algorithms::optimization_solver::logistic_loss::Parameter //ADDED { typedef daal::algorithms::optimization_solver::logistic_loss::Parameter super; Parameter(size_t numberOfTerms, daal::data_management::NumericTablePtr batchIndices = daal::data_management::NumericTablePtr(), const DAAL_UINT64 resultsToCompute = daal::algorithms::optimization_solver::objective_function::gradient) : super(numberOfTerms, batchIndices, resultsToCompute) { test = 0; } Parameter(const Parameter& other) : super(other) { test = other.test; } /** * Checks the correctness of the parameter * \return Status of computations */ virtual daal::services::Status check() const DAAL_C11_OVERRIDE { return super::check(); } float test; }; template class BatchContainer : public logistic_loss::BatchContainer { public: BatchContainer(daal::services::Environment::env* daalEnv) : logistic_loss::BatchContainer(daalEnv) {}; /** Default destructor */ virtual ~BatchContainer(){}; virtual daal::services::Status compute() DAAL_C11_OVERRIDE; //static float lipConst; }; template daal::services::Status BatchContainer::compute() { logistic_loss::Input *input = static_cast(this->_in); objective_function::Result *result = static_cast(this->_res); Parameter *parameter = static_cast(this->_par); // CHANGED daal::data_management::NumericTable *value = nullptr; daal::data_management::NumericTable *hessian = nullptr; daal::data_management::NumericTable *gradient = nullptr; daal::data_management::NumericTable *nonSmoothTermValue = nullptr; daal::data_management::NumericTable *proximalProjection = nullptr; daal::data_management::NumericTable *lipschitzConstant = nullptr; if(parameter->resultsToCompute & objective_function::value) value = result->get(objective_function::valueIdx).get(); if(parameter->resultsToCompute & objective_function::hessian) hessian = result->get(objective_function::hessianIdx).get(); if(parameter->resultsToCompute & objective_function::gradient) gradient = result->get(objective_function::gradientIdx).get(); if(parameter->resultsToCompute & objective_function::nonSmoothTermValue) { nonSmoothTermValue = result->get(objective_function::nonSmoothTermValueIdx).get(); } if(parameter->resultsToCompute & objective_function::proximalProjection) { proximalProjection = result->get(objective_function::proximalProjectionIdx).get(); } if(parameter->resultsToCompute & objective_function::lipschitzConstant) { lipschitzConstant = result->get(objective_function::lipschitzConstantIdx).get(); } if (proximalProjection) { /* *do here proximal projection with taking into account that saga solver will devide argument to handle step size in proximal projection *(see implementation in logistic loss kernel as example!) * */ algorithmFPType* b; HomogenNumericTable* hmgBeta = dynamic_cast*>(input->get(logistic_loss::argument).get()); b = hmgBeta->getArray(); algorithmFPType* prox; HomogenNumericTable* hmgProx = dynamic_cast*>(proximalProjection); prox = hmgProx->getArray(); int nBeta = proximalProjection->getNumberOfRows(); for (int i = 0; i < nBeta; i++) { prox[i] = b[i]; } //if (b[0] <= b[1]) //{ // prox[0] = (b[0] + b[1])/2.0; // prox[1] = (b[0] + b[1])/2.0; //} if (b[0] <= parameter->test) //CHANGED { prox[0] = parameter->test; //CHANGED } return daal::services::Status(); } else { logistic_loss::internal::LogLossKernel* defaultKernel = new logistic_loss::internal::LogLossKernel(); daal::services::Status status = defaultKernel->compute(input->get(logistic_loss::data).get(), input->get(logistic_loss::dependentVariables).get(), input->get(logistic_loss::argument).get(), value, hessian, gradient, nonSmoothTermValue, proximalProjection, lipschitzConstant, parameter); return status; } } template class Batch : public logistic_loss::Batch { public: /** * Main constructor */ typedef daal::algorithms::optimization_solver::logistic_loss::Batch super; Parameter par; //ADDED Parameter& parameter() { return *static_cast(_par); } //ADDED /** * Gets parameter of the algorithm * \return parameter of the algorithm */ const Parameter& parameter() const { return *static_cast(_par); } //ADDED Batch(size_t numberOfTerms): super(numberOfTerms), par(numberOfTerms) //CHANGED { initialize(); } virtual ~Batch() {} Batch(const Batch& other): super(other), par(other.par) //CHANGED { initialize(); } daal::services::SharedPtr > clone() const { return services::SharedPtr >(cloneImpl()); } daal::services::Status allocate() { return allocateResult(); } static daal::services::SharedPtr > create(size_t numberOfTerms) { return logistic_loss::Batch::create(numberOfTerms) }; protected: virtual Batch* cloneImpl() const DAAL_C11_OVERRIDE { return new Batch(*this); } void initialize() { daal::algorithms::Analysis::_ac = new BatchContainer(&(this->_env)); _par = ∥ //ADDED } }; }