#ifndef CUSTOM_LAYER_H__ #define CUSTOM_LAYER_H__ #include using namespace InferenceEngine; // A CustomLayerImpl class is an example implementation class __declspec(dllexport) CustomLayerImpl : public ILayerExecImpl { public: explicit CustomLayerImpl(const CNNLayer *layer) : cnnLayer(new CNNLayer(*layer)) {} virtual StatusCode getSupportedConfigurations(std::vector& conf, ResponseDesc *resp) noexcept override; virtual StatusCode init(LayerConfig& config, ResponseDesc *resp) noexcept override; virtual StatusCode execute(std::vector& inputs, std::vector& outputs, ResponseDesc *resp) noexcept override; private: CNNLayer * cnnLayer; }; class __declspec(dllexport) CustomLayerFactory : public InferenceEngine::ILayerImplFactory { public: explicit CustomLayerFactory(const CNNLayer *layer) : cnnLayer(new CNNLayer(*layer)) {} virtual ~CustomLayerFactory() {} private: CNNLayer * cnnLayer; public: // ... constructor and destructor StatusCode getShapes(const std::vector& inShapes, std::vector& outShapes, ResponseDesc *resp) noexcept override { if (cnnLayer == nullptr) { std::string errorMsg = "Cannot get cnn layer!"; errorMsg.copy(resp->msg, sizeof(resp->msg) - 1); return GENERAL_ERROR; } if (inShapes.size() != 1) { std::string errorMsg = "Incorrect input shapes!"; errorMsg.copy(resp->msg, sizeof(resp->msg) - 1); return GENERAL_ERROR; } outShapes.clear(); outShapes.emplace_back(inShapes[0]); return OK; } StatusCode getImplementations(std::vector& impls, ResponseDesc *resp) noexcept override { // Yoy can put cnnLayer to implimentation if it is necessary. impls.push_back(std::make_shared(CustomLayerImpl(cnnLayer))); return OK; } }; class __declspec(dllexport) CustomExtention : public InferenceEngine::IExtension { public: // could be used to cleanup resources void Unload() noexcept override { } // is used when destruction happens void Release() noexcept override { delete this; } // logging is used to track what is going on inside void SetLogCallback(InferenceEngine::IErrorListener &listener) noexcept override {} private: static Version ExtensionDescription; public: // gets extension version information void GetVersion(const InferenceEngine::Version *& versionInfo) const noexcept override { versionInfo = &ExtensionDescription; } StatusCode getPrimitiveTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept override { std::string type_name = "AbsVal"; types = new char *[1]; size = 1; types[0] = new char[type_name.size() + 1]; std::copy(type_name.begin(), type_name.end(), types[0]); types[0][type_name.size()] = '\0'; return OK; } StatusCode getFactoryFor(ILayerImplFactory* &factory, const CNNLayer *_cnnLayer, ResponseDesc *resp) noexcept override { if (_cnnLayer->type != "AbsVal") { std::string errorMsg = std::string("Factory for ") + _cnnLayer->type + " wasn't found!"; errorMsg.copy(resp->msg, sizeof(resp->msg) - 1); return NOT_FOUND; } factory = new CustomLayerFactory(_cnnLayer); return OK; } }; #endif