#pragma once #include "daal.h" #include "logistic_loss_dense_default_batch_kernel.h" //#include "algorithm_container_base_common.h" using namespace daal::algorithms::optimization_solver; namespace logistic_prox_function { template class BatchContainer : public logistic_loss::BatchContainer { public: BatchContainer(daal::services::Environment::env* daalEnv) : logistic_loss::BatchContainer(daalEnv) {}; /** Default destructor */ virtual ~BatchContainer(){}; virtual daal::services::Status compute() DAAL_C11_OVERRIDE; }; template daal::services::Status BatchContainer::compute() { logistic_loss::Input *input = static_cast(this->_in); objective_function::Result *result = static_cast(this->_res); logistic_loss::Parameter *parameter = static_cast(this->_par); daal::data_management::NumericTable *value = nullptr; daal::data_management::NumericTable *hessian = nullptr; daal::data_management::NumericTable *gradient = nullptr; daal::data_management::NumericTable *nonSmoothTermValue = nullptr; daal::data_management::NumericTable *proximalProjection = nullptr; daal::data_management::NumericTable *lipschitzConstant = nullptr; if(parameter->resultsToCompute & objective_function::value) value = result->get(objective_function::valueIdx).get(); if(parameter->resultsToCompute & objective_function::hessian) hessian = result->get(objective_function::hessianIdx).get(); if(parameter->resultsToCompute & objective_function::gradient) gradient = result->get(objective_function::gradientIdx).get(); if(parameter->resultsToCompute & objective_function::nonSmoothTermValue) { nonSmoothTermValue = result->get(objective_function::nonSmoothTermValueIdx).get(); } if(parameter->resultsToCompute & objective_function::proximalProjection) { proximalProjection = result->get(objective_function::proximalProjectionIdx).get(); } if(parameter->resultsToCompute & objective_function::lipschitzConstant) { lipschitzConstant = result->get(objective_function::lipschitzConstantIdx).get(); } if(proximalProjection) { /* *do here proximal projection with taking into account that saga solver will devide argument to handle step size in proximal projection *(see implementation in logistic loss kernel as example!) * */ } else { return defaultKernel->compute(input->get(logistic_loss::data).get(), input->get(logistic_loss::dependentVariables).get(), input->get(logistic_loss::argument).get(), value, hessian, gradient, nonSmoothTermValue, proximalProjection, lipschitzConstant, parameter); } } template class Batch : public logistic_loss::Batch { public: /** * Main constructor */ Batch(size_t numberOfTerms): logistic_loss::Batch(numberOfTerms) { initialize(); } virtual ~Batch() {} Batch(const Batch& other): logistic_loss::Batch(other) { initialize(); } daal::services::SharedPtr > clone() const { return services::SharedPtr >(cloneImpl()); } daal::services::Status allocate() { return allocateResult(); } static daal::services::SharedPtr > create(size_t numberOfTerms) { return logistic_loss::Batch::create(numberOfTerms) }; protected: virtual Batch* cloneImpl() const DAAL_C11_OVERRIDE { return new Batch(*this); } void initialize() { daal::algorithms::Analysis::_ac = new BatchContainer(&(this->_env)); } }; }