/**
 * Custom "logistic prox function" objective for the DAAL optimization solvers:
 * a Batch algorithm derived from logistic_loss::Batch that installs its own
 * BatchContainer in place of the one the base class would use.
 *
 * NOTE(review): `algorithmFPType` and `method` are used below, but no template
 * parameter lists are visible in this copy of the file (e.g. `template class
 * BatchContainer`). Presumably they are template parameters — TODO confirm the
 * template declarations before building.
 */
#pragma once
#include "daal.h"
#include "algorithms/optimization_solver/objective_function/logistic_loss_types.h"
#include "algorithms/optimization_solver/objective_function/logistic_loss_batch.h"

// NOTE(review): using-directives at header scope are pulled into every file that
// includes this header. Left unchanged because includers may rely on them.
using namespace daal;
using namespace daal::algorithms::optimization_solver;

namespace logistic_prox_function
{

/**
 * Computation container for logistic_prox_function::Batch.
 * Construction is delegated to logistic_loss::BatchContainer; compute() is
 * overridden (see its definition below).
 */
template class BatchContainer : public logistic_loss::BatchContainer
{
public:
    /**
     * Forwards the DAAL environment to the logistic_loss base container.
     * \param[in] daalEnv Environment pointer passed through unchanged
     */
    BatchContainer(daal::services::Environment::env* daalEnv) :logistic_loss::BatchContainer(daalEnv) {};
    /** Default destructor */
    virtual ~BatchContainer() {};
    /** Runs the computation for this container (currently a no-op, see below). */
    virtual services::Status compute() DAAL_C11_OVERRIDE;
};

/**
 * Placeholder computation.
 *
 * Does no work and returns a default-constructed services::Status (presumably
 * a success status). Because this override does not call the base
 * implementation, the logistic_loss computation is not run and no results are
 * written by this container.
 * NOTE(review): callers get no signal that nothing was computed — TODO
 * implement the prox-function computation or return an error status.
 */
template daal::services::Status BatchContainer::compute()
{
    return daal::services::Status();
}

/**
 * Batch algorithm for the logistic prox function.
 *
 * Reuses logistic_loss::Batch for input, parameter, and result handling: the
 * local input member and the cloneImpl()/allocateResult() overrides below are
 * commented out. It swaps in logistic_prox_function::BatchContainer as the
 * algorithm container (see initialize()).
 */
template class Batch : public logistic_loss::Batch
{
public:
    typedef logistic_loss::Batch super;

    typedef algorithms::optimization_solver::logistic_loss::Input InputType;         // same input type as logistic loss
    typedef algorithms::optimization_solver::logistic_loss::Parameter ParameterType; // same parameter type as logistic loss
    typedef typename super::ResultType ResultType;                                    // result type inherited from the base

    /**
     * Main constructor
     * \param[in] numberOfTerms Forwarded unchanged to the logistic_loss::Batch
     *                          constructor (presumably the number of terms in
     *                          the sum of functions)
     */
    Batch(size_t numberOfTerms) : logistic_loss::Batch(numberOfTerms)
    {
        initialize();
    }

    virtual ~Batch() {}

    /**
     * Copy constructor: copies base state via logistic_loss::Batch's copy
     * constructor, then re-runs initialize() so the copy installs its own
     * container.
     * \param[in] other Algorithm to copy
     */
    Batch(const Batch& other) : logistic_loss::Batch(other)
    {
        initialize();
    }

    /** Returns the `method` value this algorithm was instantiated with, as an int. */
    virtual int getMethod() const DAAL_C11_OVERRIDE { return(int)method; }

    /**
     * Returns a newly allocated copy of this algorithm.
     * NOTE(review): the cloneImpl() override below is commented out, so this
     * calls logistic_loss::Batch's cloneImpl(). Judging by the commented
     * override, that presumably allocates a base-class object, not a
     * logistic_prox_function::Batch. Wrapping it in a SharedPtr of the derived
     * type likely fails to compile once clone() is instantiated, and would lose
     * the custom container anyway — TODO restore the override.
     */
    services::SharedPtr > clone() const
    {
        return services::SharedPtr >(cloneImpl());
    }

    /**
     * Allocates the result structure via allocateResult(). That is the
     * logistic_loss::Batch implementation, since the override below is
     * commented out.
     */
    services::Status allocate()
    {
        return allocateResult();
    }

    /**
     * Parameter accessors: view `_par` (set to sumOfFunctionsParameter in
     * initialize()) as a logistic_loss::Parameter.
     * NOTE(review): assumes the base actually allocated a logistic_loss::Parameter
     * behind sumOfFunctionsParameter — TODO confirm.
     */
    ParameterType& parameter() { return *static_cast(_par); }
    const ParameterType& parameter() const { return *static_cast(_par); }

    //static services::SharedPtr > create(size_t numberOfTerms);

protected:
    //virtual Batch* cloneImpl() const DAAL_C11_OVERRIDE
    //{
    //  return new Batch(*this);
    //}

    //virtual services::Status allocateResult() DAAL_C11_OVERRIDE
    //{
    //  services::Status s = _result->allocate(&input, _par, (int)method);
    //  _res = _result.get();
    //  return s;
    //}

    /**
     * Wires this algorithm to its container, input, and parameter.
     * Called from both constructors, after the logistic_loss::Batch base has
     * been constructed. Hides (does not override) any base initialize().
     */
    void initialize()
    {
        //Analysis::_ac = new __DAAL_ALGORITHM_CONTAINER(batch, BatchContainer, algorithmFPType, method)(&_env);
        // Creates the container directly rather than via the
        // __DAAL_ALGORITHM_CONTAINER macro used in the commented line above.
        // NOTE(review): the base constructor has presumably already stored its own
        // container in _ac (TODO confirm for the DAAL version in use). If so,
        // overwriting it here without deleting it leaks that object.
        Analysis::_ac = new BatchContainer(&_env);
        // `input` is the logistic_loss::Batch member (the local one below is commented out).
        _in = &input;
        _par = sumOfFunctionsParameter;
    }

public:
    //InputType input; /*!< %Input data structure */
};
/** @} */
}