Program Listing for File ExemplarClusteringSubmodularFunction.h

Return to documentation for file (src/function/cpu/ExemplarClusteringSubmodularFunction.h)

#ifndef EXEMCL_FUNCTION_CPU
#define EXEMCL_FUNCTION_CPU

#include <src/function/SubmodularFunction.h>

#include <algorithm>
#include <limits>
#include <memory>
#include <utility>

namespace exemcl::cpu {
    template<typename HostDataType = float>
    class ExemplarClusteringSubmodularFunction : public SubmodularFunction {
    public:
        using SubmodularFunction::operator();

        explicit ExemplarClusteringSubmodularFunction(const MatrixX<HostDataType>& V, int workerCount = -1) :
            SubmodularFunction(workerCount), _V(std::make_unique<MatrixX<HostDataType>>(V)) {
            MatrixX<HostDataType> zeroVec = VectorX<HostDataType>::Zero(_V->cols()).transpose();
            _zeroVecValue = L(zeroVec);
        };

        double operator()(const MatrixX<double>& S) override {
            return ((const ExemplarClusteringSubmodularFunction*) (this))->operator()(S);
        };

        double operator()(const MatrixX<double>& S) const override {
            auto S_copy = std::make_unique<MatrixX<HostDataType>>(S.cast<HostDataType>());

            // Add zero vector to data copy.
            S_copy->conservativeResize(S_copy->rows() + 1, Eigen::NoChange_t());
            S_copy->row(S_copy->rows() - 1).setZero();

            // Make calculations.
            HostDataType L_2 = L(*S_copy);

            return _zeroVecValue - L_2;
        };

        const MatrixX<HostDataType>& getV() const {
            return _V;
        };

    private:
        HostDataType _zeroVecValue;
        const std::unique_ptr<MatrixX<HostDataType>> _V;

        HostDataType L(const MatrixX<HostDataType>& S_inner) const {
            auto* accuArray = new HostDataType[_V->rows()];

            for (unsigned int i = 0; i < _V->rows(); i++) {
                auto min_val = std::numeric_limits<HostDataType>::max();
                for (unsigned int j = 0; j < S_inner.rows(); j++)
                    min_val = std::min((_V->row(i) - S_inner.row(j)).squaredNorm(), min_val);
                accuArray[i] = min_val;
            }

            HostDataType accu = 0.0;
#pragma omp simd reduction(+ : accu)
            for (unsigned int i = 0; i < _V->rows(); i++)
                accu += accuArray[i];

            delete[] accuArray;
            return accu / static_cast<HostDataType>(_V->rows());
        };
    };
}

#endif // EXEMCL_FUNCTION_CPU