Keras custom layer with CNTK backend (CRF as RNN)

214 views

I am attempting to duplicate the CRF as RNN which has been implemented in Keras but uses TensorFlow as a backend (https://github.com/sadeepj/crfasrnn_keras). The Keras front-end is fine, but some of the backend code is written as a TensorFlow custom op. I am trying to duplicate this in CNTK, but I have a few questions.

import cntk as C
from cntk import ops
import copy
import numpy as np

# Force CPU evaluation: the native HighDimFilter op below is CPU-only
# (its Forward explicitly rejects non-CPU devices).
C.device.try_set_default_device(C.device.cpu())

# Register the native user function exported from the C++ shared library.
# The module name embeds the CNTK version (trailing '+' of dev builds is
# stripped); 'CreateHighDimFilter' is the exported factory symbol.
ops.register_native_user_function('HighDimFilterOp', 'Cntk.HighDimFilter-' + C.__version__.rstrip('+'), 'CreateHighDimFilter')

def high_dim_filter(image=None, rgb=None, **kwargs):
    """Instantiate the native high-dimensional (permutohedral) filter op.

    Parameters
    ----------
    image : the score/unary input to be filtered.
    rgb : the guide image supplying the color features.
    **kwargs : attributes forwarded to the native op, e.g. ``bilateral``,
        ``theta_alpha``, ``theta_beta``, ``theta_gamma``.

    Returns
    -------
    The CNTK function wrapping the registered 'HighDimFilterOp'.
    """
    if image is None or rgb is None:
        raise ValueError("high_dim_filter requires both 'image' and 'rgb'")
    # NOTE(review): wrapping each operand in list(...) mirrors the original
    # code -- confirm the native op really expects list-wrapped operands.
    inputs = [list(image), list(rgb)]
    # Deep-copy so later mutation of the caller's kwargs cannot alter the
    # attribute dictionary captured by the native op.
    layer_config = copy.deepcopy(kwargs)
    # Bug fix: the original dropped the created function on the floor, so
    # callers (e.g. custom_module.high_dim_filter(...)) received None.
    return ops.native_user_function('HighDimFilterOp', inputs, layer_config,
                                    'high_dim_filter')

This code is the Python call to my user C++ function. The C++ interface is as follows:

#include "HighDimFilter.h"

using namespace CNTK;

extern "C"
#ifdef _WIN32
    __declspec(dllexport)
#endif
Function* CreateHighDimFilter(const Variable* operands, size_t /*numOperands*/, const Dictionary* attributes, const wchar_t* name)
{
    // Factory entry point resolved by ops.register_native_user_function.
    // The CNTK runtime takes ownership of the returned raw pointer, so the
    // bare `new` (rather than a smart pointer) is what this C ABI requires.
    printf("Creating HighDimFilter\n");
    const std::vector<Variable> inputs{operands[0], operands[1]};
    return new HighDimFilter(inputs, *attributes, name);
}

and the custom function itself is defined as:

#pragma once
#include "CNTKLibrary.h"
#include "modified_permutohedral.h"

using namespace CNTK;

class HighDimFilter final : public Function
{
    // Filter configuration read from the attribute dictionary. All members
    // carry in-class initializers: the original left them uninitialized, so
    // `_bilateral` was read with an indeterminate value in the constructor
    // whenever the "bilateral" attribute was absent (undefined behavior).
    bool _bilateral{false};
    float _theta_alpha{0.0f};
    float _theta_beta{0.0f};
    float _theta_gamma{0.0f};

    // Operand order expected from the factory: scores first, guide image second.
    enum Input : uint32_t
    {
        SCORES,
        IM_INFO
    };

public:
    /// Builds the filter from its two operands plus the attribute dictionary
    /// assembled by the Python wrapper.
    ///
    /// Recognized attributes:
    ///   bilateral   (bool)   - bilateral vs. spatial-only filtering
    ///   theta_alpha (double) - spatial bandwidth of the bilateral kernel
    ///   theta_beta  (double) - color bandwidth of the bilateral kernel
    ///   theta_gamma (double) - bandwidth of the spatial-only kernel
    HighDimFilter(const std::vector<Variable>& inputs, const Dictionary& attributes, const std::wstring& name = L"HighDimFilter")
        : Function(inputs, attributes, name)
    {
        if (attributes.Contains(L"bilateral"))
            _bilateral = attributes[L"bilateral"].Value<bool>();

        if (_bilateral == false)
        {
            // Spatial-only mode: only theta_gamma is consulted.
            if (attributes.Contains(L"theta_gamma"))
                _theta_gamma = static_cast<float>(attributes[L"theta_gamma"].Value<double>());
        }
        else
        {
            // Bilateral mode: both spatial and color bandwidths are consulted.
            if (attributes.Contains(L"theta_alpha"))
                _theta_alpha = static_cast<float>(attributes[L"theta_alpha"].Value<double>());

            if (attributes.Contains(L"theta_beta"))
                _theta_beta = static_cast<float>(attributes[L"theta_beta"].Value<double>());
        }
    }

private:
    // Fills `Tensor` with the 2-feature spatial kernel: for pixel p the
    // features are (x / theta_gamma, y / theta_gamma), interleaved per pixel.
    // NOTE(review): assumes the tensor shape is (channels, height, width) in
    // row-major pixel order -- confirm against ModifiedPermutohedral's layout.
    void _compute_spatial_kernel(NDArrayViewPtr& Tensor, const float theta_gamma)
    {
        auto output_kernel = Tensor->WritableDataBuffer<float>();
        auto outputShape = Tensor->Shape();
        //auto channels = outputShape[0];
        auto height = outputShape[1];
        auto width = outputShape[2];
        const auto num_pixels = width * height;
        // size_t index avoids the signed/unsigned comparison of the original.
        for (size_t p = 0; p < num_pixels; ++p)
        {
            output_kernel[2 * p] = static_cast<float>(p % width) / theta_gamma;
            output_kernel[2 * p + 1] = static_cast<float>(p / width) / theta_gamma;
        }
    }

    // Fills `Tensor` with the 5-feature bilateral kernel: two spatial features
    // scaled by theta_alpha followed by three color features (planar RGB in
    // `Image`) scaled by theta_beta, interleaved per pixel.
    void _compute_bilateral_kernel(NDArrayViewPtr& Tensor, const NDArrayViewPtr& Image,
                                   const float theta_alpha, const float theta_beta)
    {
        auto output_kernel = Tensor->WritableDataBuffer<float>();
        auto rgb = Image->DataBuffer<float>();
        auto outputShape = Tensor->Shape();
        //auto channels = outputShape[0];
        auto height = outputShape[1];
        auto width = outputShape[2];
        const auto num_pixels = height * width;

        for (size_t p = 0; p < num_pixels; ++p)
        {
            // Spatial terms
            output_kernel[5 * p] = static_cast<float>(p % width) / theta_alpha;
            output_kernel[5 * p + 1] = static_cast<float>(p / width) / theta_alpha;

            // Color terms: rgb is planar, one num_pixels-sized plane per
            // channel (the float division needs no extra cast).
            output_kernel[5 * p + 2] = rgb[p] / theta_beta;
            output_kernel[5 * p + 3] = rgb[num_pixels + p] / theta_beta;
            output_kernel[5 * p + 4] = rgb[2 * num_pixels + p] / theta_beta;
        }
    }

    /// Forward pass. The permutohedral implementation is currently compiled
    /// out (#if 0) and the method returns nullptr; the disabled body is kept
    /// as the intended CPU-only path.
    BackPropStatePtr Forward(const std::vector<ValuePtr>& inputValues,
                             std::unordered_map<Variable, ValuePtr>& outputs,
                             const DeviceDescriptor& computeDevice,
                             const std::unordered_set<Variable>& /*outputsToRetainBackwardStateFor */) override
    {
#if 0
        auto scoresShape = inputValues[Input::SCORES]->Shape();
        auto channels = scoresShape[0];
        auto height = scoresShape[1];
        auto width = scoresShape[2];
        const auto num_pixels = width * height;
        // Lazily allocate the output value on first use.
        auto &outputValue = outputs[this->Output()];
        if (outputValue == nullptr)
        {
            outputValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(DataType::Float, scoresShape, computeDevice));
        }

        if (computeDevice.Type() != DeviceKind::CPU) 
            throw std::runtime_error("HighDimFilter: only CPU evaluation is supported at the moment.");

        ModifiedPermutohedral mp;

        if (_bilateral)
        {
            // By value, not `auto &`: binding a non-const lvalue reference to
            // the MakeSharedObject temporary (as the original did) is ill-formed.
            auto kernel_vals = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(DataType::Float, NDShape({5, height, width}), computeDevice));
            _compute_bilateral_kernel(kernel_vals->Data(), inputValues[Input::IM_INFO]->Data(),
                                     _theta_alpha, _theta_beta);
            mp.init(kernel_vals->Data(), 5, num_pixels);
            // NOTE(review): unlike the spatial branch below, this call omits
            // `channels` -- confirm ModifiedPermutohedral::compute's overloads.
            mp.compute(outputValue->Data(), inputValues[Input::SCORES]->Data(), false);
        }
        else
        {
            auto kernel_vals = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(DataType::Float, NDShape({2, height, width}), computeDevice));
            _compute_spatial_kernel(kernel_vals->Data(), _theta_gamma);
            mp.init(kernel_vals->Data(), 2, num_pixels);
            mp.compute(outputValue->Data(), inputValues[Input::SCORES]->Data(), channels, false);
        }

        // Save the guide image for Backward; it is needed to rebuild the kernel.
        return MakeSharedObject<BackPropState>(this->shared_from_this(), computeDevice, std::unordered_map<Variable, ValuePtr>({ {Inputs()[Input::IM_INFO], inputValues[Input::IM_INFO]} }));
#else
        return nullptr;
#endif
    }

    /// Backward pass. Like Forward, the body is compiled out (#if 0); the
    /// disabled code reruns the permutohedral filter in reverse on the root
    /// gradient to produce the gradient w.r.t. the SCORES input only.
    void Backward(const BackPropStatePtr& state,
                  const std::unordered_map<Variable, ValuePtr>& rootGradientValues,
                  std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs) override
    {
#if 0
        auto gradOutputVariable = Inputs()[Input::SCORES];
        auto inputVariable = Inputs()[Input::IM_INFO];
        auto &gradValue = backPropagatedGradientValuesForInputs[gradOutputVariable];
        if (gradValue == nullptr)
            gradValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(DataType::Float, gradOutputVariable.Shape(), state->Device()));

        // Guide image stashed by Forward; its shape drives the kernel size.
        auto imageData = state->SavedForwardPropValues().at(inputVariable)->Data();
        auto imageShape = imageData->Shape();
        auto channels = imageShape[0];
        auto height = imageShape[1];
        auto width = imageShape[2];
        const auto num_pixels = width * height;

        if (state->Device().Type() != DeviceKind::CPU)
            throw std::runtime_error("HighDimFilter: only CPU evaluation is supported at the moment.");

        auto rootGradientData = rootGradientValues.at(this->Output())->Data();

        ModifiedPermutohedral mp;

        if (_bilateral)
        {
            // By value, not `auto &` (see Forward).
            auto kernel_vals = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(DataType::Float, NDShape({5, height, width}), state->Device()));
            _compute_bilateral_kernel(kernel_vals->Data(), imageData,
                                      _theta_alpha, _theta_beta);
            mp.init(kernel_vals->Data(), 5, num_pixels);
            // `true` selects the reverse (transpose) filter for backprop.
            mp.compute(gradValue->Data(), rootGradientData, true);
        }
        else
        {
            auto kernel_vals = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(DataType::Float, NDShape({2, height, width}), state->Device()));
            _compute_spatial_kernel(kernel_vals->Data(), _theta_gamma);
            mp.init(kernel_vals->Data(), 2, num_pixels);
            mp.compute(gradValue->Data(), rootGradientData, channels, true);
        }
#endif
        return;
    }

    /// Op name reported to CNTK; must match the name used at registration time.
    const std::wstring& OpName() const override
    {
        static const std::wstring opName = L"HighDimFilterOp";
        return opName;
    }

    size_t CurrentVersion() const override
    {
        NOT_IMPLEMENTED;
    }

    void InferOutputs(std::vector<Variable>& /*outputs */) override
    {
        NOT_IMPLEMENTED;
    }

    // NOTE(review): returning nullptr disables cloning; confirm no caller
    // (e.g. model cloning/checkpointing) relies on a real clone.
    FunctionPtr Clone(const std::vector<Variable>& /*clonedInputs */) override
    {
        return nullptr;
    }
};

My python call looks like:

# Bilateral branch: filtering weights depend on both pixel position
# (theta_alpha) and color similarity (theta_beta).
# NOTE(review): `custom_module` and `self` come from the surrounding layer
# code not shown here.
bilateral_high_dim_filter = custom_module.high_dim_filter(image=all_ones_flat,
                                                          rgb=rgb_flat,
                                                          bilateral=True,
                                                          theta_alpha=self.theta_alpha,
                                                          theta_beta=self.theta_beta)


# Spatial-only branch: weights depend on pixel position alone (theta_gamma).
high_dim_filter = custom_module.high_dim_filter(image=all_ones_flat,
                                                rgb=rgb_flat,
                                                bilateral=False,
                                                theta_gamma=self.theta_gamma)

The questions are as follows:

1) What are the "operands" passed to native_user_function on initialization? Are they passed only at initialization (i.e., are they intended for weight and bias initialization)? How are the input operands used in the Function constructor? If I set them to None in Python, the code crashes.

2) How do I forward-propagate the filter — do I just call forward()? What arguments are required for forward propagation?

3) Does CNTK provide a numerical gradient check, similar to TensorFlow's, for verifying the gradient?

0

There are 0 best solutions below