Class: MLKEM::Math::Sampling
- Inherits: Object
- Defined in: lib/ml_kem/math/sampling.rb
Overview
Implements the sampling algorithms used in ML-KEM (Kyber): uniform sampling of polynomials directly in the NTT domain, and sampling from the centered binomial distribution (CBD).
These routines turn pseudorandom bytes (SHAKE128 XOF output for NTT-domain sampling, a caller-supplied byte array for CBD) into the 256 coefficients of a polynomial modulo q.
Instance Method Summary
- #initialize(q = Constants::Q) ⇒ Sampling (constructor)
  Initializes the Sampling object with a modulus `q`.
- #sample_ntt(b) ⇒ Array<Integer>
  Samples a polynomial in the NTT domain from a byte string using SHAKE128.
- #sample_poly_cbd(eta, b) ⇒ Array<Integer>
  Samples a polynomial using the centered binomial distribution (CBD).
Constructor Details
#initialize(q = Constants::Q) ⇒ Sampling
Initializes the Sampling object with a modulus `q`.
```ruby
# File 'lib/ml_kem/math/sampling.rb', line 21
def initialize(q = Constants::Q)
  @q = q
end
```
Instance Method Details
#sample_ntt(b) ⇒ Array<Integer>
Samples a polynomial in the NTT domain from a byte string using SHAKE128.
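Implements Algorithm 7, SampleNTT(B), of FIPS 203.
It uses rejection sampling: each 3-byte chunk of SHAKE128 output is split into two 12-bit candidates, and a candidate is kept only if it is less than q, so every accepted coefficient is uniform in [0, q).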
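A back-of-the-envelope check on the fixed 1024-byte SHAKE128 buffer, assuming `Constants::Q` is the ML-KEM modulus q = 3329. Each 12-bit candidate is accepted with probability p = q / 2^12, which gives:

```latex
p = \frac{q}{2^{12}} = \frac{3329}{4096} \approx 0.813, \qquad
\left\lfloor \tfrac{1024}{3} \right\rfloor = 341 \text{ chunks} \;\Rightarrow\; 682 \text{ candidates}, \qquad
\mathbb{E}[\text{accepted}] = 682\,p \approx 554.3 \gg 256

\mathbb{E}[\text{bytes needed for 256 coefficients}] \approx \frac{3 \cdot 256}{2p} = \frac{384 \cdot 4096}{3329} \approx 472.5
```

So the buffer normally holds roughly twice as many usable candidates as needed. If it were ever exhausted, the loop in `sample_ntt` would stop and return fewer than 256 coefficients, since it does not squeeze additional XOF output.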
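```ruby
# File 'lib/ml_kem/math/sampling.rb', line 36
def sample_ntt(b)
  # Expand the seed into a fixed 1024-byte SHAKE128 output buffer.
  xof_data = Crypto::HashFunctions.shake128(b, 1024)
  bytes = xof_data.bytes # convert once instead of on every loop iteration
  j = 0
  a = []
  i = 0
  # Consume 3 bytes at a time until 256 coefficients are accepted
  # or no complete 3-byte chunk is left in the buffer.
  while j < 256 && i < bytes.length - 2
    c = bytes[i, 3]
    d1 = c[0] + 256 * (c[1] % 16) # low 12 bits: C[0] and low nibble of C[1]
    d2 = (c[1] / 16) + 16 * c[2]  # high 12 bits: high nibble of C[1] and C[2]
    if d1 < @q
      a << d1
      j += 1
    end
    if d2 < @q && j < 256
      a << d2
      j += 1
    end
    i += 3
  end
  a
end
```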
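To make the 3-byte parsing concrete, here is a standalone sketch that applies the same rule as `sample_ntt` to a hand-picked byte array instead of SHAKE128 output. `parse_ntt_candidates` is a hypothetical helper written for illustration only, not part of the library, and it hard-codes q = 3329 as its default.

```ruby
# Hypothetical helper (not part of MLKEM): split bytes into 12-bit candidates
# exactly as SampleNTT does, keeping only those strictly below q.
def parse_ntt_candidates(bytes, q = 3329, n = 256)
  coeffs = []
  bytes.each_slice(3) do |c0, c1, c2|
    # Stop on an incomplete trailing chunk or once n coefficients are collected.
    break if c2.nil? || coeffs.length >= n
    d1 = c0 + 256 * (c1 % 16) # low 12 bits
    d2 = (c1 / 16) + 16 * c2  # high 12 bits
    coeffs << d1 if d1 < q
    coeffs << d2 if d2 < q && coeffs.length < n
  end
  coeffs
end

p parse_ntt_candidates([0x01, 0x23, 0x45, 0xFF, 0xFF, 0xFF, 0x01, 0x0D, 0x00])
# => [769, 1106, 0]
```

The first chunk yields 769 and 1106. The chunk `FF FF FF` produces 4095 twice and both are rejected. The chunk `01 0D 00` gives d1 = 3329, which is rejected because it is not strictly less than q, and d2 = 0, which is kept.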
#sample_poly_cbd(eta, b) ⇒ Array<Integer>
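Samples a polynomial using the centered binomial distribution (CBD).
Implements Algorithm 8, SamplePolyCBD_eta(B), of FIPS 203.
The input b is expected to hold 64 * eta bytes (2 * eta bits per coefficient). The result is a small noise polynomial used for secret and error terms in ML-KEM: each coefficient lies in [-eta, eta] and is stored modulo q, so negative values appear as q minus their absolute value.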
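Why the coefficients are small and centered: X and Y are each a sum of eta independent uniform bits, and X - Y + eta = X + (eta - Y) is a sum of 2 * eta uniform bits. Each coefficient, before reduction mod q, therefore follows a binomial shape on [-eta, eta]:

```latex
X = \sum_{j=0}^{\eta-1} b_{2i\eta + j}, \qquad
Y = \sum_{j=0}^{\eta-1} b_{2i\eta + \eta + j}, \qquad
\Pr[X - Y = k] = \frac{\binom{2\eta}{\eta + k}}{2^{2\eta}}, \quad k \in \{-\eta, \dots, \eta\}

\eta = 2:\quad \Pr[k = -2, -1, 0, 1, 2] = \tfrac{1}{16}, \tfrac{4}{16}, \tfrac{6}{16}, \tfrac{4}{16}, \tfrac{1}{16}
```

ML-KEM uses eta in {2, 3}, so every noise coefficient has absolute value at most 3.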
```ruby
# File 'lib/ml_kem/math/sampling.rb', line 75
def sample_poly_cbd(eta, b)
  # b is expected to be 64 * eta bytes, i.e. 2 * eta * 256 bits.
  b_bits = Math::ByteOperations.bytes_to_bits(b)
  f = Array.new(256, 0)
  256.times do |i|
    # Coefficient i uses the 2*eta-bit block starting at bit 2*i*eta.
    x = b_bits[(2 * i * eta)...((2 * i + 1) * eta)].sum
    y = b_bits[((2 * i + 1) * eta)...((2 * i + 2) * eta)].sum
    f[i] = (x - y) % @q # x - y lies in [-eta, eta]; reduce into [0, q)
  end
  f
end
```
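A standalone sketch of the same computation, handy for checking the bit indexing by hand. It assumes `Math::ByteOperations.bytes_to_bits` uses the FIPS 203 little-endian bit order (bit j of byte i becomes bit 8i + j). The `bytes_to_bits` and `cbd` helpers below are hypothetical re-implementations, not the library's API; q = 3329 is hard-coded as a default, and unlike the method above the sketch rejects inputs that are not 64 * eta bytes.

```ruby
# Hypothetical helper: little-endian bit order, bit j of byte i -> bits[8 * i + j].
def bytes_to_bits(bytes)
  bytes.flat_map { |byte| (0...8).map { |j| (byte >> j) & 1 } }
end

# Hypothetical sketch of SamplePolyCBD_eta over a plain Array of byte values.
def cbd(eta, bytes, q = 3329)
  raise ArgumentError, "expected #{64 * eta} bytes" unless bytes.length == 64 * eta
  bits = bytes_to_bits(bytes)
  Array.new(256) do |i|
    x = bits[2 * i * eta, eta].sum       # first eta bits of block i
    y = bits[(2 * i + 1) * eta, eta].sum # second eta bits of block i
    (x - y) % q                          # value in [-eta, eta], stored mod q
  end
end

p cbd(2, [0x03] * 128).first(4) # => [2, 0, 2, 0]
p cbd(2, [0x0C] * 128).first(2) # => [3327, 0]
```

With eta = 2 each byte covers two coefficients: 0x03 has low bits 1,1 then 0,0, giving +2 followed by 0, while 0x0C gives -2, stored as 3327.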