Class: DSPy::Optimizers::GaussianProcess

Inherits: Object
Extended by:
T::Sig
Defined in:
lib/dspy/optimizers/gaussian_process.rb

Overview

Gaussian Process regression backed by Numo::TinyLinalg for Bayesian optimization.

Instance Method Summary

Constructor Details

#initialize(length_scale: 1.0, signal_variance: 1.0, noise_variance: 1e-6) ⇒ GaussianProcess

Returns a new instance of GaussianProcess.



15
16
17
18
19
20
# File 'lib/dspy/optimizers/gaussian_process.rb', line 15

def initialize(length_scale: 1.0, signal_variance: 1.0, noise_variance: 1e-6)
  # A fresh model has no training data; #predict raises until #fit has run.
  @fitted = T.let(false, T::Boolean)

  # Hyperparameters are stored as given. noise_variance is added to the
  # diagonal of the training kernel matrix in #fit. length_scale and
  # signal_variance are presumably read by the RBF kernel helper (not shown
  # here) — NOTE(review): confirm against rbf_kernel.
  @noise_variance = noise_variance
  @signal_variance = signal_variance
  @length_scale = length_scale
end

Instance Method Details

#fit(x_train, y_train) ⇒ Object



23
24
25
26
27
28
29
30
31
32
33
34
35
36
# File 'lib/dspy/optimizers/gaussian_process.rb', line 23

# Fits the GP posterior to training data by factoring the noisy kernel matrix
# and caching K^-1 y for later predictions.
#
# @param x_train training inputs (converted via to_matrix; one sample per row
#   is assumed — NOTE(review): confirm to_matrix's layout)
# @param y_train training targets (converted via to_vector)
# @return [true]
# @raise [ArgumentError] if the number of samples in x_train and y_train differ
#
# All instance state is committed only after the factorisation and solve
# succeed, so a failed refit leaves a previously fitted model untouched.
def fit(x_train, y_train)
  x_matrix = to_matrix(x_train)
  y_vector = to_vector(y_train)

  sample_count = x_matrix.shape[0]
  if sample_count != y_vector.size
    raise ArgumentError,
          "x_train has #{sample_count} samples but y_train has #{y_vector.size} targets"
  end

  kernel_matrix = rbf_kernel(x_matrix, x_matrix)
  # Add the observation-noise variance to the diagonal (K + sigma_n^2 I).
  kernel_matrix += Numo::DFloat.eye(kernel_matrix.shape[0]) * @noise_variance

  # Work in locals first: if either call raises, nothing below has run and the
  # previous fit (if any) remains internally consistent.
  cholesky_factor = Numo::TinyLinalg.cholesky(kernel_matrix, uplo: 'L')
  alpha = Numo::TinyLinalg.cho_solve(cholesky_factor, y_vector, uplo: 'L')

  @cholesky_factor = cholesky_factor
  @alpha = alpha
  @x_train = x_matrix
  @y_train = y_vector
  @fitted = true
end

#predict(x_test, return_std: false) ⇒ Object



42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# File 'lib/dspy/optimizers/gaussian_process.rb', line 42

# Predicts the posterior mean (and optionally the standard deviation) at the
# given test inputs.
#
# @param x_test test inputs (converted via to_matrix)
# @param return_std [Boolean] when true, also return per-point posterior std
# @return the posterior mean, or [mean, std] when return_std is true
# @raise [RuntimeError] if #fit has not been called
def predict(x_test, return_std: false)
  raise 'Gaussian Process not fitted' unless @fitted

  test_matrix = to_matrix(x_test)
  # Cross-covariance between training and test points; its transpose is dotted
  # with alpha (length n_train), so it is laid out (n_train, n_test).
  k_star = rbf_kernel(T.must(@x_train), test_matrix)

  # Posterior mean: k_*^T K^-1 y, with alpha = K^-1 y cached by #fit.
  mean = k_star.transpose.dot(T.must(@alpha))
  return mean unless return_std

  # v = K^-1 k_*, one column per test point.
  v = Numo::TinyLinalg.cho_solve(T.must(@cholesky_factor), k_star, uplo: 'L')

  # Only the diagonal of the posterior covariance k_** - k_*^T v is needed.
  # diag(k_*^T v)[j] = sum_i k_*[i, j] * v[i, j], so a column-wise sum of the
  # elementwise product avoids forming the full (n_test, n_test) product.
  prior_variance = rbf_kernel(test_matrix, test_matrix).diagonal
  explained_variance = (k_star * v).sum(axis: 0)
  variance = prior_variance - explained_variance

  # Round-off can make variances slightly negative; floor them before sqrt.
  variance[variance < 1e-12] = 1e-12
  std = Numo::NMath.sqrt(variance)

  [mean, std]
end