Program Listing for File SubmodularFunction.h
Return to documentation for file (src/function/SubmodularFunction.h)
#ifndef EXEMCL_SUBM_FUNCTION_H
#define EXEMCL_SUBM_FUNCTION_H
#include <cstddef>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include <src/io/DataTypes.h>
namespace exemcl {
class SubmodularFunction {
public:
SubmodularFunction(int workerCount = -1) {
setWorkerCount(workerCount);
}
virtual double operator()(const MatrixX<double>& S) const = 0;
virtual double operator()(const MatrixX<double>& S) = 0;
virtual double operator()(const MatrixX<double>& S, VectorXRef<double> elem) const {
if (S.cols() == elem.size()) {
std::unique_ptr<MatrixX<double>> S_elem = std::make_unique<MatrixX<double>>(S);
S_elem->conservativeResize(S.rows() + 1, Eigen::NoChange_t());
S_elem->row(S.rows()) << elem.transpose();
// Call subsequent operators.
auto marginalGain = operator()(*S_elem) - operator()(S);
// Return value.
return marginalGain;
} else
throw std::runtime_error("SubmodularFunction::operator(): The number of columns in matrix `S` and the number of elements in vector `elem` do not match ("
+ std::to_string(S.cols()) + " vs. " + std::to_string(elem.size()) + ").");
}
virtual double operator()(const MatrixX<double>& S, VectorXRef<double> elem) {
return ((const SubmodularFunction*) (this))->operator()(S, elem);
}
virtual std::vector<double> operator()(const std::vector<MatrixX<double>>& S_multi) const {
// Construct vector for storing utilities.
std::vector<double> utilities;
utilities.resize(S_multi.size());
// Calculate utilities.
#pragma omp parallel for num_threads(_workerCount)
for (unsigned long i = 0; i < S_multi.size(); i++)
utilities[i] = operator()(S_multi[i]);
// Return value.
return utilities;
};
virtual std::vector<double> operator()(const std::vector<MatrixX<double>>& S_multi) {
return ((const SubmodularFunction*) (this))->operator()(S_multi);
};
virtual std::vector<double> operator()(const std::vector<MatrixX<double>>& S_multi, VectorXRef<double> elem) const {
auto S_multi_elem = std::make_unique<std::vector<MatrixX<double>>>();
S_multi_elem->reserve(S_multi.size());
// Create a new S_multi set, but include the marginal vector.
for (auto& S_elem : S_multi) {
S_multi_elem->push_back(S_elem);
S_multi_elem->back().conservativeResize(S_elem.rows() + 1, Eigen::NoChange_t());
S_multi_elem->back().row(S_elem.rows()) << elem.transpose();
}
// Evaluate S_multi_elem and S_multi.
auto utilityS_multi = operator()(S_multi);
auto utilityS_multi_elem = operator()(*S_multi_elem);
// Calculate the difference between the utilities of S_multi_elem and S_multi.
std::vector<double> marginalGains;
marginalGains.resize(S_multi.size());
for (unsigned long i = 0; i < utilityS_multi.size(); i++)
marginalGains[i] = utilityS_multi_elem[i] - utilityS_multi[i];
return marginalGains;
}
virtual std::vector<double> operator()(const std::vector<MatrixX<double>>& S_multi, VectorXRef<double> elem) {
return ((const SubmodularFunction*) (this))->operator()(S_multi, elem);
}
virtual std::vector<double> operator()(const MatrixX<double>& S, std::vector<VectorXRef<double>> elems) const {
// Create a vector, which will hold {S u e_1}, ..., {S u e_n}
auto S_elems = std::make_unique<std::vector<MatrixX<double>>>(elems.size(), S);
// Build {S u e_1}, ..., {S u e_n}.
for (unsigned int i = 0; i < elems.size(); i++) {
auto& elem = elems[i];
(*S_elems)[i].conservativeResize(S.rows() + 1, Eigen::NoChange_t());
(*S_elems)[i].row(S.rows()) << elem.transpose();
}
// Evaluate S.
auto S_funcValue = operator()(S);
// Evaluate all S with elems.
auto S_elems_funcValue = operator()(*S_elems);
// Create a result vector.
std::vector<double> gains;
gains.resize(elems.size());
// Fill the results.
for (unsigned int i = 0; i < elems.size(); i++)
gains[i] = S_elems_funcValue[i] - S_funcValue;
return gains;
}
virtual std::vector<double> operator()(const MatrixX<double>& S, std::vector<VectorXRef<double>> elems) {
return ((const SubmodularFunction*) (this))->operator()(std::move(S), std::move(elems));
}
virtual unsigned int getWorkerCount() const {
return _workerCount;
};
virtual void setWorkerCount(int workerCount) {
if (workerCount >= 1)
_workerCount = workerCount;
else {
auto suggestedThreads = std::thread::hardware_concurrency();
_workerCount = suggestedThreads > 0 ? suggestedThreads : 1;
}
}
virtual void setMemoryLimit(long memoryLimit) {
throw std::runtime_error("SubmodularFunction::setMemoryLimit: Not implemented.");
}
virtual ~SubmodularFunction() = default;
protected:
unsigned int _workerCount = 1;
};
}
#endif // EXEMCL_SUBM_FUNCTION_H