class gpr::laplace::SteinVariational

Overview

#include <SteinVariational.h>
 
class SteinVariational {
public:
    // typedefs
 
    typedef std::function<Eigen::VectorXd(const Eigen::VectorXd& theta)> GradOracle;
 
    // construction
 
    SteinVariational();
 
    // methods
 
    void initFromCCD(const Eigen::VectorXd& theta_star, const Eigen::MatrixXd& L, double f0 = 0.0);
    int fit(GradOracle grad_oracle, int T = 30, double eps = 0.05, double conv_tol = 1e-3);
    int fit(GradOracle grad_oracle, const SteinVariationalConfig& cfg);
    const std::vector<Eigen::VectorXd>& particles() const;
    std::size_t numParticles() const;
    Eigen::VectorXd uniformWeights() const;
    bool fitted() const;
    void invalidate();
};

Detailed Documentation

Typedefs

typedef std::function<Eigen::VectorXd(const Eigen::VectorXd& theta)> GradOracle

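Type of the log-posterior gradient oracle. The caller provides a callable that, at a given theta, returns grad_log_p(theta) = -grad_NLML(theta) + grad_log_prior(theta), where NLML is the negative log marginal likelihood. The oracle therefore returns the gradient of the log posterior, not the gradient of the NLML itself.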
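Because the oracle's contract is a sum of two gradients, a caller will typically assemble it from separate NLML and prior gradients. The sketch below shows that composition under stated assumptions: it uses std::vector<double> in place of Eigen::VectorXd so it compiles without Eigen, and the names VecGradOracle and makeGradLogPosterior are hypothetical, not part of this class's API.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Vec = std::vector<double>;
// Hypothetical std::vector analogue of GradOracle (the real type uses Eigen::VectorXd).
using VecGradOracle = std::function<Vec(const Vec&)>;

// Builds grad_log_p(theta) = -grad_NLML(theta) + grad_log_prior(theta)
// from two caller-supplied gradient callables.
VecGradOracle makeGradLogPosterior(VecGradOracle grad_nlml, VecGradOracle grad_log_prior) {
    return [grad_nlml, grad_log_prior](const Vec& theta) {
        const Vec a = grad_nlml(theta);       // gradient of the NLML
        const Vec b = grad_log_prior(theta);  // gradient of the log prior
        assert(a.size() == theta.size() && b.size() == theta.size());
        Vec g(theta.size());
        for (std::size_t d = 0; d < theta.size(); ++d) g[d] = -a[d] + b[d];
        return g;
    };
}
```

The sign flip on the NLML term is the part most easily gotten wrong: SVGD ascends log p, so the oracle must point uphill in posterior density.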

Methods

void initFromCCD(const Eigen::VectorXd& theta_star, const Eigen::MatrixXd& L, double f0 = 0.0)

Initialise the particles (one per CCD design point, K in total) from a CCD-style grid centred at theta_star and stretched by L, so that theta_k = theta_star + L * z_k for each design point z_k.

theta_star: p-dimensional MAP point.
L: p x p matrix mapping z to theta, typically U * Lambda^{-1/2} from the LaplaceINLA decomposition. It provides a sensible initial spread that the SVGD updates then refine.
f0: CCD radius. With the default argument (0.0), the Rue 2009 value sqrt(p + 2) is used.
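A minimal sketch of the affine map theta_k = theta_star + L * z_k, assuming a textbook central composite design: one centre point, 2p axial points at distance f0 along each axis, and the 2^p full-factorial corners rescaled onto the same radius. Rue 2009 uses a fractional factorial to keep the point count down, so the actual K produced by initFromCCD may differ. The f0 <= 0 sentinel, the helper names stretch and ccdParticles, and the use of std::vector in place of Eigen are all assumptions for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <initializer_list>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;  // p x p, row-major

// Returns theta_star + L * z.
Vec stretch(const Vec& theta_star, const Mat& L, const Vec& z) {
    Vec theta = theta_star;
    for (std::size_t i = 0; i < theta.size(); ++i)
        for (std::size_t j = 0; j < z.size(); ++j) theta[i] += L[i][j] * z[j];
    return theta;
}

// Hypothetical helper: centre + 2p axial + 2^p corner points, all non-centre
// points at radius f0 in z-space, then mapped to theta-space through L.
std::vector<Vec> ccdParticles(const Vec& theta_star, const Mat& L, double f0 = 0.0) {
    const std::size_t p = theta_star.size();
    assert(L.size() == p);
    if (f0 <= 0.0) f0 = std::sqrt(static_cast<double>(p) + 2.0);  // assumed default-radius sentinel
    std::vector<Vec> z_points{Vec(p, 0.0)};                         // centre point
    for (std::size_t j = 0; j < p; ++j) {                           // axial (star) points
        for (double sign : {-1.0, 1.0}) {
            Vec z(p, 0.0);
            z[j] = sign * f0;
            z_points.push_back(z);
        }
    }
    const double c = f0 / std::sqrt(static_cast<double>(p));        // corner coordinate
    for (std::size_t mask = 0; mask < (std::size_t{1} << p); ++mask) {
        Vec z(p);
        for (std::size_t j = 0; j < p; ++j) z[j] = ((mask >> j) & 1u) ? c : -c;
        z_points.push_back(z);
    }
    std::vector<Vec> particles;
    for (const Vec& z : z_points) particles.push_back(stretch(theta_star, L, z));
    return particles;
}
```

For p = 2 and the default radius this yields 9 particles, all but the centre lying at z-distance 2 from theta_star; a non-identity L then shears and scales that sphere into an ellipsoid matching the Laplace curvature.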

int fit(GradOracle grad_oracle, int T = 30, double eps = 0.05, double conv_tol = 1e-3)

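Run up to T SVGD iterations using the given gradient oracle.

T: maximum number of iterations.
eps: fixed step size; values from 0.01 to 0.1 are typical. Liu & Wang 2016 use 0.1 with AdaGrad scaling; a fixed step is used here for simplicity, since the update magnitudes are already O(1) after RBF-kernel averaging.
conv_tol: convergence tolerance used to stop early.

Returns the number of iterations actually performed: at most T, and fewer if the run converged early.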
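For reference, one SVGD step moves every particle along phi(x_i) = (1/n) sum_j [k(x_j, x_i) grad_log_p(x_j) + grad_{x_j} k(x_j, x_i)] (Liu & Wang 2016): the first term pulls particles toward high posterior density, and the kernel-gradient term pushes them apart so they do not collapse onto the mode. The sketch below assumes an RBF kernel k(x, y) = exp(-||x - y||^2 / h) with a median-heuristic bandwidth and a fixed step eps. The stopping rule (largest particle displacement below conv_tol) is an assumption, since the criterion fit() actually uses is not documented here; svgdFit and VecGradOracle are hypothetical names, and std::vector stands in for Eigen.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

using Vec = std::vector<double>;
// Hypothetical std::vector analogue of GradOracle.
using VecGradOracle = std::function<Vec(const Vec&)>;

double sqDist(const Vec& a, const Vec& b) {
    double s = 0.0;
    for (std::size_t d = 0; d < a.size(); ++d) s += (a[d] - b[d]) * (a[d] - b[d]);
    return s;
}

// Median heuristic: h = med / log(n + 1), where med is the median squared
// pairwise distance; falls back to 1.0 when there are no pairs or med == 0.
double medianBandwidth(const std::vector<Vec>& x) {
    std::vector<double> d2;
    for (std::size_t i = 0; i < x.size(); ++i)
        for (std::size_t j = i + 1; j < x.size(); ++j) d2.push_back(sqDist(x[i], x[j]));
    if (d2.empty()) return 1.0;
    auto mid = d2.begin() + static_cast<std::ptrdiff_t>(d2.size() / 2);
    std::nth_element(d2.begin(), mid, d2.end());
    const double h = *mid / std::log(static_cast<double>(x.size()) + 1.0);
    return h > 0.0 ? h : 1.0;
}

// Runs at most T SVGD iterations on x in place; returns the number performed.
int svgdFit(std::vector<Vec>& x, const VecGradOracle& grad_log_p,
            int T = 30, double eps = 0.05, double conv_tol = 1e-3) {
    const std::size_t n = x.size();
    if (n == 0) return 0;
    const std::size_t p = x[0].size();
    int it = 0;
    while (it < T) {
        const double h = medianBandwidth(x);
        std::vector<Vec> g(n);
        for (std::size_t j = 0; j < n; ++j) {
            g[j] = grad_log_p(x[j]);
            assert(g[j].size() == p);
        }
        // Compute every phi(x_i) before moving any particle (synchronous update).
        std::vector<Vec> phi(n, Vec(p, 0.0));
        for (std::size_t i = 0; i < n; ++i) {
            for (std::size_t j = 0; j < n; ++j) {
                const double k = std::exp(-sqDist(x[j], x[i]) / h);  // k(x_j, x_i)
                for (std::size_t d = 0; d < p; ++d)
                    // driving term k * grad_log_p(x_j) + repulsive term grad_{x_j} k(x_j, x_i)
                    phi[i][d] += k * g[j][d] + (2.0 / h) * (x[i][d] - x[j][d]) * k;
            }
        }
        double max_step = 0.0;
        for (std::size_t i = 0; i < n; ++i) {
            double step2 = 0.0;
            for (std::size_t d = 0; d < p; ++d) {
                const double s = eps * phi[i][d] / static_cast<double>(n);
                x[i][d] += s;
                step2 += s * s;
            }
            max_step = std::max(max_step, std::sqrt(step2));
        }
        ++it;
        if (max_step < conv_tol) break;  // assumed stopping rule
    }
    return it;
}
```

Each iteration costs O(n^2 p) kernel evaluations plus n oracle calls; with a standard-normal target, two particles started at -1 and 1 settle symmetrically at a finite separation instead of collapsing onto the mode, which is the repulsive term at work.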

int fit(GradOracle grad_oracle, const SteinVariationalConfig& cfg)

Config-struct overload of fit. Forwards to the primitive fit(grad_oracle, T, eps, conv_tol) above. New code should prefer this form: it keeps hardcoded literals from proliferating at call sites, and future options can be added to the config without breaking the caller API.
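A minimal sketch of the forwarding pattern, assuming a config struct whose fields mirror the primitive overload's parameters. SvgdConfigSketch and FitterSketch are hypothetical stand-ins (the actual fields of SteinVariationalConfig are not documented here), the oracle parameter is omitted for brevity, and the primitive fit is a stub that only records its arguments.

```cpp
#include <cassert>

// Hypothetical stand-in for SteinVariationalConfig: the defaults live in one place.
struct SvgdConfigSketch {
    int T = 30;
    double eps = 0.05;
    double conv_tol = 1e-3;
};

class FitterSketch {
public:
    // Primitive overload (stub): records its arguments and reports T iterations.
    int fit(int T, double eps, double conv_tol) {
        last_T = T;
        last_eps = eps;
        last_conv_tol = conv_tol;
        return T;
    }
    // Config overload: unpacks the struct and forwards, so call sites never
    // repeat the literals.
    int fit(const SvgdConfigSketch& cfg) { return fit(cfg.T, cfg.eps, cfg.conv_tol); }

    int last_T = 0;
    double last_eps = 0.0;
    double last_conv_tol = 0.0;
};
```

A caller overrides only what it needs (for example, setting cfg.eps = 0.1 and leaving the rest at their defaults). Adding a field later touches only the struct and the forwarding line, and existing call sites that build a config keep compiling.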

const std::vector<Eigen::VectorXd>& particles() const

Read-only access to the particle set; intended for use after fit().

Eigen::VectorXd uniformWeights() const

Uniform mixture weights (1/K each). Included for API parity with LaplaceINLA's gridWeights(), which returns the Rue-style integration weights of the CCD points.
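Because both classes expose a (points, weights) pair, downstream code can treat SVGD particles and CCD grid points the same way. A minimal sketch, assuming the points and weights have already been copied into std::vector containers (Eigen omitted so the block is self-contained); uniformWeightsSketch and weightedMean are hypothetical helpers, not members of either class.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;

// Mirrors uniformWeights(): K weights of 1/K each.
Vec uniformWeightsSketch(std::size_t K) {
    return Vec(K, K == 0 ? 0.0 : 1.0 / static_cast<double>(K));
}

// Weighted posterior mean sum_k w_k * theta_k; works unchanged for SVGD
// particles with uniform weights or CCD grid points with non-uniform weights.
Vec weightedMean(const std::vector<Vec>& points, const Vec& w) {
    assert(!points.empty() && points.size() == w.size());
    Vec mean(points[0].size(), 0.0);
    for (std::size_t k = 0; k < points.size(); ++k)
        for (std::size_t d = 0; d < mean.size(); ++d) mean[d] += w[k] * points[k][d];
    return mean;
}
```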