|
//! Implementation of the IND-CCA2 post-quantum secure key encapsulation mechanism (KEM)
//! ML-KEM (NIST FIPS-203 publication) and CRYSTALS-Kyber (v3.02/"draft00" CFRG draft).
//!
//! The namespace `d00` refers to the version currently implemented, in accordance with the CFRG draft.
//! The `nist` namespace refers to the FIPS-203 publication.
//!
//! Quoting from the CFRG I-D:
//!
//! Kyber is not a Diffie-Hellman (DH) style non-interactive key
//! agreement, but instead, Kyber is a Key Encapsulation Method (KEM).
//! In essence, a KEM is a Public-Key Encryption (PKE) scheme where the
//! plaintext cannot be specified, but is generated as a random key as
//! part of the encryption. A KEM can be transformed into an unrestricted
//! PKE using HPKE (RFC9180). On its own, a KEM can be used as a key
//! agreement method in TLS.
//!
//! Kyber is an IND-CCA2 secure KEM. It is constructed by applying a
//! Fujisaki-Okamoto style transformation on InnerPKE, which is the
//! underlying IND-CPA secure Public Key Encryption scheme. We cannot
//! use InnerPKE directly, as its ciphertexts are malleable.
//!
//! ```
//!                     F.O. transform
//!     InnerPKE   ---------------------->   Kyber
//!     IND-CPA                              IND-CCA2
//! ```
//!
//! Kyber is a lattice-based scheme. More precisely, its security is
//! based on the learning-with-errors-and-rounding problem in module
//! lattices (MLWER). The underlying polynomial ring R (defined in
//! Section 5) is chosen such that multiplication is very fast using the
//! number theoretic transform (NTT, see Section 5.1.3).
//!
//! An InnerPKE private key is a vector _s_ over R of length k which is
//! _small_ in a particular way. Here k is a security parameter akin to
//! the size of a prime modulus. For Kyber512, which targets AES-128's
//! security level, the value of k is 2.
//!
//! The public key consists of two values:
//!
//! * _A_ a uniformly sampled k by k matrix over R _and_
//!
//! * _t = A s + e_, where e is a suitably small masking vector.
//!
//! Distinguishing between such A s + e and a uniformly sampled t is the
//! module learning-with-errors (MLWE) problem. If that is hard, then it
//! is also hard to recover the private key from the public key as that
//! would allow you to distinguish between those two.
//!
//! To save space in the public key, A is recomputed deterministically
//! from a seed _rho_.
//!
//! A ciphertext for a message m under this public key is a pair (c_1,
//! c_2) computed roughly as follows:
//!
//! c_1 = Compress(A^T r + e_1, d_u)
//! c_2 = Compress(t^T r + e_2 + Decompress(m, 1), d_v)
//!
//! where
//!
//! * e_1, e_2 and r are small blinds;
//!
//! * Compress(-, d) removes some information, leaving d bits per
//!   coefficient and Decompress is such that Compress after Decompress
//!   does nothing and
//!
//! * d_u, d_v are scheme parameters.
//!
//! Distinguishing such a ciphertext and uniformly sampled (c_1, c_2) is
//! an example of the full MLWER problem, see section 4.4 of [KyberV302].
//!
//! To decrypt the ciphertext, one computes
//!
//! m = Compress(Decompress(c_2, d_v) - s^T Decompress(c_1, d_u), 1).
//!
//! It is not straightforward to see that this formula is correct. In
//! fact, there is a negligible but non-zero probability that a ciphertext
//! does not decrypt correctly, given by the DFP column in Table 4. This
//! failure probability can be computed by a careful automated analysis
//! of the probabilities involved, see kyber_failure.py of [SecEst].
//!
//! [KyberV302](https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf)
//! [I-D](https://github.com/bwesterb/draft-schwabe-cfrg-kyber)
//! [SecEst](https://github.com/pq-crystals/security-estimates)
d00Length (in bytes) of a shared secret. |
// TODO
//
// - The bottlenecks in Kyber are the various hash/XOF calls:
//    - Optimize Zig's Keccak implementation.
//    - Use SIMD to compute Keccak in parallel.
// - Can we track bounds of coefficients using comptime types without
//   duplicating code?
// - It would be neater to have tests closer to the code under test.
// - When generating a keypair, we have a copy of the inner public key with
//   its large matrix A in both the public key and the private key. In Go we
//   can just have a pointer in the private key to the public key, but
//   how do we do this elegantly in Zig?
Kyber512Length (in bytes) of a seed for deterministic encapsulation. |
const std = @import("std"); const builtin = @import("builtin"); |
Kyber768Length (in bytes) of a seed for key generation. |
const testing = std.testing; const assert = std.debug.assert; const crypto = std.crypto; const errors = std.crypto.errors; const math = std.math; const mem = std.mem; const RndGen = std.Random.DefaultPrng; const sha3 = crypto.hash.sha3; |
Kyber1024Algorithm name. |
// The prime modulus q of the base ring ℤ_q, written out as its binary
// decomposition: q = 2¹¹ + 2¹⁰ + 2⁸ + 1 = 3329.
const Q: i16 = (1 << 11) + (1 << 10) + (1 << 8) + 1;
nistA shared secret, and an encapsulated (encrypted) representation of it. |
// The Montgomery radix R = 2¹⁶.
const R: i32 = 0x1_0000;
MLKem512A Kyber public key. |
// n: the degree of the polynomials in the base ring, i.e. the number of
// coefficients each polynomial carries.
const N: usize = 256;
MLKem768Size of a serialized representation of the key, in bytes. |
// η₂: the bound on the "small" noise vectors used as encryption blinds.
const eta2: u8 = 2;
MLKem1024Generates a shared secret, and encapsulates it for the public key.
If |
// Compile-time description of one Kyber / ML-KEM parameter set.
const Params = struct {
    // Human-readable algorithm name.
    name: []const u8,

    // Selects the NIST ML-KEM variant (true) instead of Kyber as
    // originally submitted (false, the default).
    ml_kem: bool = false,

    // k: the matrix A is k×k.
    k: u8,

    // η₁: bound on the "small" vectors used in the private key and
    // in encryption blinds.
    eta1: u8,

    // d_u: bits retained per coefficient of u, the part of the
    // ciphertext that does not depend on the private key.
    du: u8,

    // d_v: bits retained per coefficient of v, the part of the
    // ciphertext that depends on the private key.
    dv: u8,
};
EncapsulatedSecretSerializes the key into a byte array. |
/// Kyber parameter sets following the CFRG draft ("draft00").
/// `ml_kem` is left at its default (false) for all of them.
pub const d00 = struct {
    pub const Kyber512 = Kyber(.{ .name = "Kyber512", .k = 2, .eta1 = 3, .du = 10, .dv = 4 });
    pub const Kyber768 = Kyber(.{ .name = "Kyber768", .k = 3, .eta1 = 2, .du = 10, .dv = 4 });
    pub const Kyber1024 = Kyber(.{ .name = "Kyber1024", .k = 4, .eta1 = 2, .du = 11, .dv = 5 });
};
encaps()Deterministically derive a key pair from a cryptograpically secure secret seed.
Except in tests, applications should generally call |
/// ML-KEM parameter sets following the NIST FIPS-203 publication.
/// They share k/η₁/d_u/d_v with the matching `d00` sets but set `ml_kem`.
pub const nist = struct {
    pub const MLKem512 = Kyber(.{ .name = "ML-KEM-512", .ml_kem = true, .k = 2, .eta1 = 3, .du = 10, .dv = 4 });
    pub const MLKem768 = Kyber(.{ .name = "ML-KEM-768", .ml_kem = true, .k = 3, .eta1 = 2, .du = 10, .dv = 4 });
    pub const MLKem1024 = Kyber(.{ .name = "ML-KEM-1024", .ml_kem = true, .k = 4, .eta1 = 2, .du = 11, .dv = 5 });
};
SecretKeyDeserializes the key from a byte array. |
// Every parameter set instantiated in this file.
const modes = [_]type{
    d00.Kyber512,
    d00.Kyber768,
    d00.Kyber1024,
    nist.MLKem512,
    nist.MLKem768,
    nist.MLKem1024,
};

// Length in bytes of H(pk), the SHA3-256 digest of the public key.
const h_length: usize = 32;

// Length in bytes of the seed from which the inner PKE keypair is derived.
const inner_seed_length: usize = 32;

// Length in bytes of a seed for deterministic encapsulation.
const common_encaps_seed_length: usize = 32;

// Length in bytes of a shared secret.
const common_shared_key_size: usize = 32;
bytes_length: |
// Builds the KEM type (keys, encapsulation, decapsulation) for the
// parameter set `p`. When `p.ml_kem` is set, the FIPS 203 (ML-KEM) variant
// is produced; otherwise the original Kyber (CFRG draft00) variant.
fn Kyber(comptime p: Params) type {
    return struct {
        // Size of a ciphertext, in bytes.
        pub const ciphertext_length = Poly.compressedSize(p.du) * p.k + Poly.compressedSize(p.dv);

        const Self = @This();
        const V = Vec(p.k);
        const M = Mat(p.k);

        /// Length (in bytes) of a shared secret.
        pub const shared_length = common_shared_key_size;
        /// Length (in bytes) of a seed for deterministic encapsulation.
        pub const encaps_seed_length = common_encaps_seed_length;
        /// Length (in bytes) of a seed for key generation.
        pub const seed_length: usize = inner_seed_length + shared_length;
        /// Algorithm name.
        pub const name = p.name;

        /// A shared secret, and an encapsulated (encrypted) representation of it.
        pub const EncapsulatedSecret = struct {
            shared_secret: [shared_length]u8,
            ciphertext: [ciphertext_length]u8,
        };

        /// A Kyber public key.
        pub const PublicKey = struct {
            pk: InnerPk,

            // Cached
            hpk: [h_length]u8, // H(pk)

            /// Size of a serialized representation of the key, in bytes.
            pub const bytes_length = InnerPk.bytes_length;

            /// Generates a shared secret, and encapsulates it for the public key.
            /// If `seed` is `null`, a random seed is used. This is recommended.
            /// If `seed` is set, encapsulation is deterministic.
            pub fn encaps(pk: PublicKey, seed_: ?[encaps_seed_length]u8) EncapsulatedSecret {
                var m: [inner_plaintext_length]u8 = undefined;

                if (seed_) |seed| {
                    if (p.ml_kem) {
                        // ML-KEM uses the seed as m directly.
                        @memcpy(&m, &seed);
                    } else {
                        // m = H(seed)
                        sha3.Sha3_256.hash(&seed, &m, .{});
                    }
                } else {
                    crypto.random.bytes(&m);
                }

                // (K', r) = G(m ‖ H(pk))
                var kr: [inner_plaintext_length + h_length]u8 = undefined;
                var g = sha3.Sha3_512.init(.{});
                g.update(&m);
                g.update(&pk.hpk);
                g.final(&kr);

                // c = innerEncrypt(pk, m, r)
                const ct = pk.pk.encrypt(&m, kr[32..64]);

                if (p.ml_kem) {
                    return EncapsulatedSecret{
                        .shared_secret = kr[0..shared_length].*, // ML-KEM: K = K'
                        .ciphertext = ct,
                    };
                } else {
                    // Compute H(c) and put in second slot of kr, which will be (K', H(c)).
                    sha3.Sha3_256.hash(&ct, kr[32..], .{});

                    var ss: [shared_length]u8 = undefined;
                    sha3.Shake256.hash(&kr, &ss, .{});
                    return EncapsulatedSecret{
                        .shared_secret = ss, // Kyber: K = KDF(K' ‖ H(c))
                        .ciphertext = ct,
                    };
                }
            }

            /// Serializes the key into a byte array.
            pub fn toBytes(pk: PublicKey) [bytes_length]u8 {
                return pk.pk.toBytes();
            }

            /// Deserializes the key from a byte array.
            pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!PublicKey {
                var ret: PublicKey = undefined;
                ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.bytes_length]);
                sha3.Sha3_256.hash(buf, &ret.hpk, .{});
                return ret;
            }
        };

        /// A Kyber secret key.
        pub const SecretKey = struct {
            sk: InnerSk,
            pk: InnerPk,
            hpk: [h_length]u8, // H(pk)
            z: [shared_length]u8,

            /// Size of a serialized representation of the key, in bytes.
            pub const bytes_length: usize = InnerSk.bytes_length +
                InnerPk.bytes_length + h_length + shared_length;

            /// Decapsulates the shared secret within ct using the private key.
            ///
            /// The recovered plaintext is re-encrypted. If the result differs
            /// from `ct`, an implicit-rejection key derived from the secret `z`
            /// is returned instead of K'' (selected in constant time), rather
            /// than an error:
            ///   - ML-KEM (FIPS 203): K̄ = J(z ‖ c) = SHAKE256(z ‖ c)
            ///   - Kyber: K = KDF(z ‖ H(c))
            pub fn decaps(sk: SecretKey, ct: *const [ciphertext_length]u8) ![shared_length]u8 {
                // m' = innerDec(ct)
                const m2 = sk.sk.decrypt(ct);

                // (K'', r') = G(m' ‖ H(pk))
                var kr2: [64]u8 = undefined;
                var g = sha3.Sha3_512.init(.{});
                g.update(&m2);
                g.update(&sk.hpk);
                g.final(&kr2);

                // ct' = innerEnc(pk, m', r')
                const ct2 = sk.pk.encrypt(&m2, kr2[32..64]);

                // Constant-time flag: set iff ct ≠ ct'.
                const reject = ctneq(ciphertext_length, ct.*, ct2);

                if (p.ml_kem) {
                    // ML-KEM: the rejection key is K̄ = J(z ‖ c) = SHAKE256(z ‖ c),
                    // so that every invalid ciphertext yields a different,
                    // pseudo-random key (returning z itself would not).
                    var k_bar: [shared_length]u8 = undefined;
                    var j = sha3.Shake256.init(.{});
                    j.update(&sk.z);
                    j.update(ct);
                    j.final(&k_bar);

                    // K = K'' if ct = ct', K̄ otherwise.
                    cmov(shared_length, kr2[0..shared_length], k_bar, reject);
                    return kr2[0..shared_length].*;
                } else {
                    // Compute H(ct) and put in the second slot of kr2 which will be (K'', H(ct)).
                    sha3.Sha3_256.hash(ct, kr2[32..], .{});

                    // Replace K'' by z in the first slot of kr2 if ct ≠ ct'.
                    cmov(32, kr2[0..32], sk.z, reject);

                    // Kyber: K = KDF(K''/z ‖ H(c))
                    var ss: [shared_length]u8 = undefined;
                    sha3.Shake256.hash(&kr2, &ss, .{});
                    return ss;
                }
            }

            /// Serializes the key into a byte array.
            pub fn toBytes(sk: SecretKey) [bytes_length]u8 {
                return sk.sk.toBytes() ++ sk.pk.toBytes() ++ sk.hpk ++ sk.z;
            }

            /// Deserializes the key from a byte array.
            pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!SecretKey {
                var ret: SecretKey = undefined;
                comptime var s: usize = 0;
                ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.bytes_length]);
                s += InnerSk.bytes_length;
                ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.bytes_length]);
                s += InnerPk.bytes_length;
                ret.hpk = buf[s..][0..h_length].*;
                s += h_length;
                ret.z = buf[s..][0..shared_length].*;
                return ret;
            }
        };

        /// A Kyber key pair.
        pub const KeyPair = struct {
            secret_key: SecretKey,
            public_key: PublicKey,

            /// Deterministically derive a key pair from a cryptographically secure secret seed.
            ///
            /// Except in tests, applications should generally call `generate()` instead of this function.
            pub fn generateDeterministic(seed: [seed_length]u8) !KeyPair {
                var ret: KeyPair = undefined;

                // Generate inner key
                innerKeyFromSeed(
                    seed[0..inner_seed_length].*,
                    &ret.public_key.pk,
                    &ret.secret_key.sk,
                );
                ret.secret_key.pk = ret.public_key.pk;

                // Copy over z from seed.
                ret.secret_key.z = seed[inner_seed_length..seed_length].*;

                // Compute H(pk)
                sha3.Sha3_256.hash(&ret.public_key.pk.toBytes(), &ret.secret_key.hpk, .{});
                ret.public_key.hpk = ret.secret_key.hpk;

                return ret;
            }

            /// Generate a new, random key pair.
            pub fn generate() KeyPair {
                var random_seed: [seed_length]u8 = undefined;
                while (true) {
                    crypto.random.bytes(&random_seed);
                    return generateDeterministic(random_seed) catch {
                        @branchHint(.unlikely);
                        continue;
                    };
                }
            }
        };

        // Size of plaintexts of the inner PKE.
        const inner_plaintext_length: usize = Poly.compressedSize(1);

        // Public key of the inner PKE.
        const InnerPk = struct {
            rho: [32]u8, // ρ, the seed for the matrix A
            th: V, // NTT(t), normalized

            // Cached values
            aT: M,

            const bytes_length = V.bytes_length + 32;

            fn encrypt(
                pk: InnerPk,
                pt: *const [inner_plaintext_length]u8,
                seed: *const [32]u8,
            ) [ciphertext_length]u8 {
                // Sample r, e₁ and e₂ appropriately
                const rh = V.noise(p.eta1, 0, seed).ntt().barrettReduce();
                const e1 = V.noise(eta2, p.k, seed);
                const e2 = Poly.noise(eta2, 2 * p.k, seed);

                // Next we compute u = Aᵀ r + e₁. First Aᵀ.
                var u: V = undefined;
                for (0..p.k) |i| {
                    // Note that coefficients of r are bounded by q and those of Aᵀ
                    // are bounded by 4.5q and so their product is bounded by 2¹⁵q
                    // as required for multiplication.
                    u.ps[i] = pk.aT.vs[i].dotHat(rh);
                }

                // Aᵀ and r were not in Montgomery form, so the Montgomery
                // multiplications in the inner product added a factor R⁻¹ which
                // the InvNTT cancels out.
                u = u.barrettReduce().invNTT().add(e1).normalize();

                // Next, compute v = <t, r> + e₂ + Decompress_q(m, 1)
                const v = pk.th.dotHat(rh).barrettReduce().invNTT()
                    .add(Poly.decompress(1, pt)).add(e2).normalize();

                return u.compress(p.du) ++ v.compress(p.dv);
            }

            fn toBytes(pk: InnerPk) [bytes_length]u8 {
                return pk.th.toBytes() ++ pk.rho;
            }

            fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!InnerPk {
                var ret: InnerPk = undefined;

                const th_bytes = buf[0..V.bytes_length];
                ret.th = V.fromBytes(th_bytes).normalize();

                if (p.ml_kem) {
                    // Verify that the coefficients used a canonical representation.
                    if (!mem.eql(u8, &ret.th.toBytes(), th_bytes)) {
                        return error.NonCanonical;
                    }
                }

                ret.rho = buf[V.bytes_length..bytes_length].*;
                ret.aT = M.uniform(ret.rho, true);
                return ret;
            }
        };

        // Private key of the inner PKE
        const InnerSk = struct {
            sh: V, // NTT(s), normalized
            const bytes_length = V.bytes_length;

            fn decrypt(sk: InnerSk, ct: *const [ciphertext_length]u8) [inner_plaintext_length]u8 {
                const u = V.decompress(p.du, ct[0..comptime V.compressedSize(p.du)]);
                const v = Poly.decompress(
                    p.dv,
                    ct[comptime V.compressedSize(p.du)..ciphertext_length],
                );

                // Compute m = v - <s, u>
                return v.sub(sk.sh.dotHat(u.ntt()).barrettReduce().invNTT())
                    .normalize().compress(1);
            }

            fn toBytes(sk: InnerSk) [bytes_length]u8 {
                return sk.sh.toBytes();
            }

            fn fromBytes(buf: *const [bytes_length]u8) InnerSk {
                var ret: InnerSk = undefined;
                ret.sh = V.fromBytes(buf).normalize();
                return ret;
            }
        };

        // Derives inner PKE keypair from given seed d.
        //
        // (ρ, σ) = G(d ‖ k) for ML-KEM (FIPS 203, K-PKE.KeyGen) and
        // (ρ, σ) = G(d) for Kyber.
        fn innerKeyFromSeed(seed: [inner_seed_length]u8, pk: *InnerPk, sk: *InnerSk) void {
            var expanded_seed: [64]u8 = undefined;
            var h = sha3.Sha3_512.init(.{});
            h.update(&seed);
            // ML-KEM appends k as a single domain-separation byte *after* d.
            if (p.ml_kem) h.update(&[1]u8{p.k});
            h.final(&expanded_seed);
            pk.rho = expanded_seed[0..32].*;
            const sigma = expanded_seed[32..64];
            pk.aT = M.uniform(pk.rho, false); // Expand ρ to A; we'll transpose later on

            // Sample secret vector s.
            sk.sh = V.noise(p.eta1, 0, sigma).ntt().normalize();

            const eh = Vec(p.k).noise(p.eta1, p.k, sigma).ntt(); // sample blind e.
            var th: V = undefined;

            // Next, we compute t = A s + e.
            for (0..p.k) |i| {
                // Note that coefficients of s are bounded by q and those of A
                // are bounded by 4.5q and so their product is bounded by 2¹⁵q
                // as required for multiplication.
                // A and s were not in Montgomery form, so the Montgomery
                // multiplications in the inner product added a factor R⁻¹ which
                // we'll cancel out with toMont(). This will also ensure the
                // coefficients of th are bounded in absolute value by q.
                th.ps[i] = pk.aT.vs[i].dotHat(sk.sh).toMont();
            }

            pk.th = th.add(eh).normalize(); // bounded by 8q
            pk.aT = pk.aT.transpose();
        }
    };
}

// R mod q
const r_mod_q: i32 = @rem(@as(i32, R), Q);

// R² mod q
const r2_mod_q: i32 = @rem(r_mod_q * r_mod_q, Q);

// ζ is the degree 256 primitive root of unity used for the NTT.
const zeta: i16 = 17;

// (128)⁻¹ R². Used in inverse NTT.
const r2_over_128: i32 = @mod(invertMod(128, Q) * r2_mod_q, Q);

// zetas lists precomputed powers of the primitive root of unity in
// Montgomery representation used for the NTT:
//
//  zetas[i] = ζᵇʳᵛ⁽ⁱ⁾ R mod q
//
// where ζ = 17, brv(i) is the bitreversal of a 7-bit number and R=2¹⁶ mod q.
const zetas = computeZetas();

// invNTTReductions keeps track of which coefficients to apply Barrett
// reduction to in Poly.invNTT().
//
// Generated lazily: once a butterfly is computed which is about to
// overflow the i16, the largest coefficient is reduced. If that is
// not enough, the other coefficient is reduced as well.
//
// This is actually optimal, as proven in https://eprint.iacr.org/2020/1377.pdf
// TODO generate comptime?
const inv_ntt_reductions = [_]i16{
    -1, // after layer 1
    -1, // after layer 2
    16,
    17,
    48,
    49,
    80,
    81,
    112,
    113,
    144,
    145,
    176,
    177,
    208,
    209,
    240,
    241,
    -1, // after layer 3
    0,
    1,
    32,
    33,
    34,
    35,
    64,
    65,
    96,
    97,
    98,
    99,
    128,
    129,
    160,
    161,
    162,
    163,
    192,
    193,
    224,
    225,
    226,
    227,
    -1, // after layer 4
    2,
    3,
    66,
    67,
    68,
    69,
    70,
    71,
    130,
    131,
    194,
    195,
    196,
    197,
    198,
    199,
    -1, // after layer 5
    4,
    5,
    6,
    7,
    132,
    133,
    134,
    135,
    136,
    137,
    138,
    139,
    140,
    141,
    142,
    143,
    -1, // after layer 6
    -1, //  after layer 7
};

test "invNTTReductions bounds" {
    // Checks whether the reductions proposed by invNTTReductions
    // don't overflow during invNTT().
    var xs = [_]i32{1} ** 256; // start at |x| ≤ q

    var r: usize = 0;
    var layer: math.Log2Int(usize) = 1;
    while (layer < 8) : (layer += 1) {
        const w = @as(usize, 1) << layer;
        var i: usize = 0;

        while (i + w < 256) {
            xs[i] = xs[i] + xs[i + w];
            try testing.expect(xs[i] <= 9); // we can't exceed 9q
            xs[i + w] = 1;
            i += 1;
            if (@mod(i, w) == 0) {
                i += w;
            }
        }

        while (true) {
            const j = inv_ntt_reductions[r];
            r += 1;
            if (j < 0) {
                break;
            }
            xs[@as(usize, @intCast(j))] = 1;
        }
    }
}

// Extended euclidean algorithm.
//
// For a, b finds x, y such that  x a + y b = gcd(a, b). Used to compute
// modular inverse.
fn eea(a: anytype, b: @TypeOf(a)) EeaResult(@TypeOf(a)) {
    if (a == 0) {
        return .{ .gcd = b, .x = 0, .y = 1 };
    }
    const r = eea(@rem(b, a), a);
    return .{ .gcd = r.gcd, .x = r.y - @divTrunc(b, a) * r.x, .y = r.x };
}

fn EeaResult(comptime T: type) type {
    return struct { gcd: T, x: T, y: T };
}

// Returns least common multiple of a and b.
fn lcm(a: anytype, b: @TypeOf(a)) @TypeOf(a) {
    const r = eea(a, b);
    return a * b / r.gcd;
}

// Invert modulo p.
fn invertMod(a: anytype, p: @TypeOf(a)) @TypeOf(a) {
    const r = eea(a, p);
    assert(r.gcd == 1);
    return r.x;
}

// Reduce mod q for testing.
fn modQ32(x: i32) i16 {
    var y = @as(i16, @intCast(@rem(x, @as(i32, Q))));
    if (y < 0) {
        y += Q;
    }
    return y;
}

// Given -2¹⁵ q ≤ x < 2¹⁵ q, returns -q < y < q with x 2⁻¹⁶ = y (mod q).
fn montReduce(x: i32) i16 {
    const qInv = comptime invertMod(@as(i32, Q), R);
    // This is Montgomery reduction with R=2¹⁶.
    //
    // Note gcd(2¹⁶, q) = 1 as q is prime. Write q' := 62209 = q⁻¹ mod R.
    // First we compute
    //
    //  m := ((x mod R) q') mod R
    //     = x q' mod R
    //     = int16(x q')
    //     = int16(int32(x) * int32(q'))
    //
    // Note that x q' might be as big as 2³² and could overflow the int32
    // multiplication in the last line. However for any int32s a and b,
    // we have int32(int64(a)*int64(b)) = int32(a*b) and so the result is ok.
    const m: i16 = @truncate(@as(i32, @truncate(x *% qInv)));

    // Note that x - m q is divisible by R; indeed modulo R we have
    //
    //  x - m q ≡ x - x q' q ≡ x - x q⁻¹ q ≡ x - x = 0.
    //
    // We return y := (x - m q) / R. Note that y is indeed correct as
    // modulo q we have
    //
    //  y ≡ x R⁻¹ - m q R⁻¹ = x R⁻¹
    //
    // and as both -2¹⁵ q ≤ m q, x < 2¹⁵ q, we have
    // -2¹⁶ q < x - m q < 2¹⁶ q and so -q < (x - m q) / R < q as desired.
    const yR = x - @as(i32, m) * @as(i32, Q);
    return @bitCast(@as(u16, @truncate(@as(u32, @bitCast(yR)) >> 16)));
}

test "Test montReduce" {
    var rnd = RndGen.init(0);
    for (0..1000) |_| {
        const bound = comptime @as(i32, Q) * (1 << 15);
        const x = rnd.random().intRangeLessThan(i32, -bound, bound);
        const y = montReduce(x);
        try testing.expect(-Q < y and y < Q);
        try testing.expectEqual(modQ32(x), modQ32(@as(i32, y) * R));
    }
}

// Given any x, return x R mod q where R=2¹⁶.
fn feToMont(x: i16) i16 {
    // Note |1353 x| ≤ 1353 2¹⁵ ≤ 13318 q ≤ 2¹⁵ q and so we're within
    // the bounds of montReduce.
    return montReduce(@as(i32, x) * r2_mod_q);
}

test "Test feToMont" {
    var x: i32 = -(1 << 15);
    while (x < 1 << 15) : (x += 1) {
        const y = feToMont(@as(i16, @intCast(x)));
        try testing.expectEqual(modQ32(@as(i32, y)), modQ32(x * r_mod_q));
    }
}

// Given any x, compute 0 ≤ y ≤ q with x = y (mod q).
//
// Beware: we might have feBarrettReduce(x) = q ≠ 0 for some x. In fact,
// this happens if and only if x = -nq for some positive integer n.
fn feBarrettReduce(x: i16) i16 {
    // This is standard Barrett reduction.
    //
    // For any x we have x mod q = x - ⌊x/q⌋ q. We will use 20159/2²⁶ as
    // an approximation of 1/q. Note that  0 ≤ 20159/2²⁶ - 1/q ≤ 0.135/2²⁶
    // and so | x 20159/2²⁶ - x/q | ≤ 2⁻¹⁰ for |x| ≤ 2¹⁶. For all x
    // not a multiple of q, the number x/q is further than 1/q from any integer
    // and so ⌊x 20159/2²⁶⌋ = ⌊x/q⌋. If x is a multiple of q and x is positive,
    // then x 20159/2²⁶ is larger than x/q so ⌊x 20159/2²⁶⌋ = ⌊x/q⌋ as well.
    // Finally, if x is negative multiple of q, then ⌊x 20159/2²⁶⌋ = ⌊x/q⌋-1.
    // Thus
    //                        [ q        if x=-nq for pos. integer n
    //  x - ⌊x 20159/2²⁶⌋ q = [
    //                        [ x mod q  otherwise
    //
    // To actually compute this, note that
    //
    //  ⌊x 20159/2²⁶⌋ = (20159 x) >> 26.
    return x -% @as(i16, @intCast((@as(i32, x) * 20159) >> 26)) *% Q;
}

test "Test Barrett reduction" {
    var x: i32 = -(1 << 15);
    while (x < 1 << 15) : (x += 1) {
        var y1 = feBarrettReduce(@as(i16, @intCast(x)));
        const y2 = @mod(@as(i16, @intCast(x)), Q);
        if (x < 0 and @rem(-x, Q) == 0) {
            y1 -= Q;
        }
        try testing.expectEqual(y1, y2);
    }
}

// Returns x if x < q and x - q otherwise. Assumes x ≥ -29439.
fn csubq(x: i16) i16 {
    var r = x;
    r -= Q;
    r += (r >> 15) & Q;
    return r;
}

test "Test csubq" {
    var x: i32 = -29439;
    while (x < 1 << 15) : (x += 1) {
        const y1 = csubq(@as(i16, @intCast(x)));
        var y2 = @as(i16, @intCast(x));
        if (@as(i16, @intCast(x)) >= Q) {
            y2 -= Q;
        }
        try testing.expectEqual(y1, y2);
    }
}

// Compute a^s mod p.
fn mpow(a: anytype, s: @TypeOf(a), p: @TypeOf(a)) @TypeOf(a) {
    var ret: @TypeOf(a) = 1;
    var s2 = s;
    var a2 = a;

    while (true) {
        if (s2 & 1 == 1) {
            ret = @mod(ret * a2, p);
        }
        s2 >>= 1;
        if (s2 == 0) {
            break;
        }
        a2 = @mod(a2 * a2, p);
    }
    return ret;
}

// Computes zetas table used by ntt and invNTT.
fn computeZetas() [128]i16 {
    @setEvalBranchQuota(10000);
    var ret: [128]i16 = undefined;
    for (&ret, 0..) |*r, i| {
        const t = @as(i16, @intCast(mpow(@as(i32, zeta), @bitReverse(@as(u7, @intCast(i))), Q)));
        r.* = csubq(feBarrettReduce(feToMont(t)));
    }
    return ret;
}

// An element of our base ring R which are polynomials over ℤ_q
// modulo the equation Xᴺ = -1, where q=3329 and N=256.
//
// This type is also used to store NTT-transformed polynomials,
// see Poly.NTT().
//
// Coefficients aren't always reduced. See Normalize().
const Poly = struct { cs: [N]i16, const bytes_length = N / 2 * 3; const zero: Poly = .{ .cs = .{0} ** N }; fn add(a: Poly, b: Poly) Poly { var ret: Poly = undefined; for (0..N) |i| { ret.cs[i] = a.cs[i] + b.cs[i]; } return ret; } fn sub(a: Poly, b: Poly) Poly { var ret: Poly = undefined; for (0..N) |i| { ret.cs[i] = a.cs[i] - b.cs[i]; } return ret; } // For testing, generates a random polynomial with for each // coefficient |x| ≤ q. fn randAbsLeqQ(rnd: anytype) Poly { var ret: Poly = undefined; for (0..N) |i| { ret.cs[i] = rnd.random().intRangeAtMost(i16, -Q, Q); } return ret; } // For testing, generates a random normalized polynomial. fn randNormalized(rnd: anytype) Poly { var ret: Poly = undefined; for (0..N) |i| { ret.cs[i] = rnd.random().intRangeLessThan(i16, 0, Q); } return ret; } // Executes a forward "NTT" on p. // // Assumes the coefficients are in absolute value ≤q. The resulting // coefficients are in absolute value ≤7q. If the input is in Montgomery // form, then the result is in Montgomery form and so (by linearity of the NTT) // if the input is in regular form, then the result is also in regular form. fn ntt(a: Poly) Poly { // Note that ℤ_q does not have a primitive 512ᵗʰ root of unity (as 512 // does not divide into q-1) and so we cannot do a regular NTT. ℤ_q // does have a primitive 256ᵗʰ root of unity, the smallest of which // is ζ := 17. // // Recall that our base ring R := ℤ_q[x] / (x²⁵⁶ + 1). The polynomial // x²⁵⁶+1 will not split completely (as its roots would be 512ᵗʰ roots // of unity.) 
However, it does split almost (using ζ¹²⁸ = -1): // // x²⁵⁶ + 1 = (x²)¹²⁸ - ζ¹²⁸ // = ((x²)⁶⁴ - ζ⁶⁴)((x²)⁶⁴ + ζ⁶⁴) // = ((x²)³² - ζ³²)((x²)³² + ζ³²)((x²)³² - ζ⁹⁶)((x²)³² + ζ⁹⁶) // ⋮ // = (x² - ζ)(x² + ζ)(x² - ζ⁶⁵)(x² + ζ⁶⁵) … (x² + ζ¹²⁷) // // Note that the powers of ζ that appear (from the second line down) are // in binary // // 0100000 1100000 // 0010000 1010000 0110000 1110000 // 0001000 1001000 0101000 1101000 0011000 1011000 0111000 1111000 // … // // That is: brv(2), brv(3), brv(4), …, where brv(x) denotes the 7-bit // bitreversal of x. These powers of ζ are given by the Zetas array. // // The polynomials x² ± ζⁱ are irreducible and coprime, hence by // the Chinese Remainder Theorem we know // // ℤ_q[x]/(x²⁵⁶+1) → ℤ_q[x]/(x²-ζ) x … x ℤ_q[x]/(x²+ζ¹²⁷) // // given by a ↦ ( a mod x²-ζ, …, a mod x²+ζ¹²⁷ ) // is an isomorphism, which is the "NTT". It can be efficiently computed by // // // a ↦ ( a mod (x²)⁶⁴ - ζ⁶⁴, a mod (x²)⁶⁴ + ζ⁶⁴ ) // ↦ ( a mod (x²)³² - ζ³², a mod (x²)³² + ζ³², // a mod (x²)⁹⁶ - ζ⁹⁶, a mod (x²)⁹⁶ + ζ⁹⁶ ) // // et cetera // If N was 8 then this can be pictured in the following diagram: // // https://cnx.org/resources/17ee4dfe517a6adda05377b25a00bf6e6c93c334/File0026.png // // Each cross is a Cooley-Tukey butterfly: it's the map // // (a, b) ↦ (a + ζb, a - ζb) // // for the appropriate power ζ for that column and row group. var p = a; var k: usize = 0; // index into zetas var l = N >> 1; while (l > 1) : (l >>= 1) { // On the nᵗʰ iteration of the l-loop, the absolute value of the // coefficients are bounded by nq. // offset effectively loops over the row groups in this column; it is // the first row in the row group. var offset: usize = 0; while (offset < N - l) : (offset += 2 * l) { k += 1; const z = @as(i32, zetas[k]); // j loops over each butterfly in the row group. 
for (offset..offset + l) |j| {
                const t = montReduce(z * @as(i32, p.cs[j + l]));
                p.cs[j + l] = p.cs[j] - t;
                p.cs[j] += t;
            }
        }
    }

    return p;
}

// Executes an inverse "NTT" on p and multiplies by the Montgomery factor R.
//
// Assumes the coefficients are in absolute value ≤q. The resulting
// coefficients are in absolute value ≤q. If the input is in Montgomery
// form, then the result is in Montgomery form and so (by linearity)
// if the input is in regular form, then the result is also in regular form.
fn invNTT(a: Poly) Poly {
    var k: usize = 127; // index into zetas; walked backwards, one entry per butterfly group
    var r: usize = 0; // index into inv_ntt_reductions
    var p = a;

    // We basically do the opposite of NTT, but postpone dividing by 2 in the
    // inverse of the Cooley-Tukey butterfly and accumulate that into a big
    // division by 2⁷ at the end. See the comments in the ntt() function.

    // l is the butterfly distance: 2, 4, …, N/2 (the reverse order of ntt()).
    var l: usize = 2;
    while (l < N) : (l <<= 1) {
        var offset: usize = 0;
        while (offset < N - l) : (offset += 2 * l) {
            // As we're inverting, we need powers of ζ⁻¹ (instead of ζ).
            // To be precise, we need ζᵇʳᵛ⁽ᵏ⁾⁻¹²⁸. However, as ζ⁻¹²⁸ = -1,
            // we can use the existing zetas table instead of
            // keeping a separate invZetas table as in Dilithium.

            const minZeta = @as(i32, zetas[k]);
            k -= 1;

            for (offset..offset + l) |j| {
                // Gentleman-Sande butterfly: (a, b) ↦ (a + b, ζ(a-b))
                const t = p.cs[j + l] - p.cs[j];
                p.cs[j] += p.cs[j + l];
                p.cs[j + l] = montReduce(minZeta * @as(i32, t));

                // Note that if we had |a| < αq and |b| < βq before the
                // butterfly, then now we have |a| < (α+β)q and |b| < q.
            }
        }

        // We let inv_ntt_reductions instruct us which coefficients to
        // Barrett reduce. Each layer consumes indices from the table until it
        // reads a negative sentinel, which ends that layer's run.
        while (true) {
            const i = inv_ntt_reductions[r];
            r += 1;
            if (i < 0) {
                break;
            }
            p.cs[@as(usize, @intCast(i))] = feBarrettReduce(p.cs[@as(usize, @intCast(i))]);
        }
    }

    for (0..N) |j| {
        // Note 1441 = (128)⁻¹ R². The coefficients are bounded by 9q, so
        // as 1441 * 9 ≈ 2¹⁴ < 2¹⁵, we're within the required bounds
        // for montReduce().
        p.cs[j] = montReduce(r2_over_128 * @as(i32, p.cs[j]));
    }

    return p;
}

// Normalizes coefficients.
//
// Ensures each coefficient is in {0, …, q-1}: a Barrett reduction followed
// by a conditional subtraction of q (csubq).
fn normalize(a: Poly) Poly {
    var ret: Poly = undefined;
    for (0..N) |i| {
        ret.cs[i] = csubq(feBarrettReduce(a.cs[i]));
    }
    return ret;
}

// Returns a in Montgomery form (feToMont applied to every coefficient).
fn toMont(a: Poly) Poly {
    var ret: Poly = undefined;
    for (0..N) |i| {
        ret.cs[i] = feToMont(a.cs[i]);
    }
    return ret;
}

// Barrett reduces every coefficient.
//
// Beware, this does not fully normalize coefficients (no final csubq).
fn barrettReduce(a: Poly) Poly {
    var ret: Poly = undefined;
    for (0..N) |i| {
        ret.cs[i] = feBarrettReduce(a.cs[i]);
    }
    return ret;
}

// Number of bytes needed to pack N coefficients of d bits each.
fn compressedSize(comptime d: u8) usize {
    return @divTrunc(N * d, 8);
}

// Returns packed Compress_q(p, d).
//
// Assumes p is normalized. The d-bit outputs are packed into bytes
// least-significant bit first. Work is done in batches of lcm(d, 8) bits, so
// each batch of in_batch_size coefficients fills exactly out_batch_size bytes
// and the bit-level packing can be fully unrolled at comptime.
fn compress(p: Poly, comptime d: u8) [compressedSize(d)]u8 {
    @setEvalBranchQuota(10000);
    const q_over_2: u32 = comptime @divTrunc(Q, 2); // (q-1)/2
    const two_d_min_1: u32 = comptime (1 << d) - 1; // 2ᵈ-1
    var in_off: usize = 0;
    var out_off: usize = 0;

    const batch_size: usize = comptime lcm(@as(i16, d), 8);
    const in_batch_size: usize = comptime batch_size / d;
    const out_batch_size: usize = comptime batch_size / 8;

    const out_length: usize = comptime @divTrunc(N * d, 8);
    comptime assert(out_length * 8 == d * N);
    // Zero-initialised because packing below ORs bits into place.
    var out = [_]u8{0} ** out_length;

    while (in_off < N) {
        // First we compress into in.
        var in: [in_batch_size]u16 = undefined;
        inline for (0..in_batch_size) |i| {
            // Compress_q(x, d) = ⌈(2ᵈ/q)x⌋ mod⁺ 2ᵈ
            //                  = ⌊(2ᵈ/q)x+½⌋ mod⁺ 2ᵈ
            //                  = ⌊((x << d) + q/2) / q⌋ mod⁺ 2ᵈ
            //                  = DIV((x << d) + q/2, q) & ((1<<d) - 1)
            const t = @as(u24, @intCast(p.cs[in_off + i])) << d;
            // Division by invariant multiplication, equivalent to DIV(t + q/2, q).
            // A division may not be a constant-time operation, even with a constant denominator.
            // Here, side channels would leak information about the shared secret, see https://kyberslash.cr.yp.to
            // Multiplication, on the other hand, is a constant-time operation on the CPUs we currently support.
            // The asserts below pin the assumptions the multiply-and-shift relies on:
            // the input range (d ≤ 11) and 20642679 / 2³⁶ approximating 1/q.
            comptime assert(d <= 11);
            comptime assert(((20642679 * @as(u64, Q)) >> 36) == 1);
            const u: u32 = @intCast((@as(u64, t + q_over_2) * 20642679) >> 36);
            in[i] = @intCast(u & two_d_min_1);
        }

        // Now we pack the d-bit integers from `in' into out as bytes.
        // in_shift is the bit position inside in[i]; j the byte inside the
        // current output batch. All of these are comptime, so this fully unrolls.
        comptime var in_shift: usize = 0;
        comptime var j: usize = 0;
        comptime var i: usize = 0;
        inline while (i < in_batch_size) : (j += 1) {
            comptime var todo: usize = 8;
            inline while (todo > 0) {
                const out_shift = comptime 8 - todo;
                out[out_off + j] |= @as(u8, @truncate((in[i] >> in_shift) << out_shift));

                const done = comptime @min(@min(d, todo), d - in_shift);
                todo -= done;
                in_shift += done;

                if (in_shift == d) {
                    in_shift = 0;
                    i += 1;
                }
            }
        }

        in_off += in_batch_size;
        out_off += out_batch_size;
    }

    return out;
}

// Returns the polynomial whose coefficients are Decompress_q(x, d) for each
// d-bit value x unpacked (least-significant bit first) from in.
fn decompress(comptime d: u8, in: *const [compressedSize(d)]u8) Poly {
    @setEvalBranchQuota(10000);
    const in_len = comptime @divTrunc(N * d, 8);
    comptime assert(in_len * 8 == d * N);
    var ret: Poly = undefined;
    var in_off: usize = 0;
    var out_off: usize = 0;

    // Mirror of compress(): each batch of in_batch_size bytes yields
    // out_batch_size coefficients.
    const batch_size: usize = comptime lcm(@as(i16, d), 8);
    const in_batch_size: usize = comptime batch_size / 8;
    const out_batch_size: usize = comptime batch_size / d;

    while (out_off < N) {
        comptime var in_shift: usize = 0;
        comptime var j: usize = 0;
        comptime var i: usize = 0;
        inline while (i < out_batch_size) : (i += 1) {
            // First, unpack next coefficient.
            comptime var todo = d;
            var out: u16 = 0;

            inline while (todo > 0) {
                const out_shift = comptime d - todo;
                const m = comptime (1 << d) - 1;
                out |= (@as(u16, in[in_off + j] >> in_shift) << out_shift) & m;

                const done = comptime @min(@min(8, todo), 8 - in_shift);
                todo -= done;
                in_shift += done;

                if (in_shift == 8) {
                    in_shift = 0;
                    j += 1;
                }
            }

            // Decompress_q(x, d) = ⌈(q/2ᵈ)x⌋
            //                    = ⌊(q/2ᵈ)x+½⌋
            //                    = ⌊(qx + 2ᵈ⁻¹)/2ᵈ⌋
            //                    = (qx + (1<<(d-1))) >> d
            const qx = @as(u32, out) * @as(u32, Q);
            ret.cs[out_off + i] = @as(i16, @intCast((qx + (1 << (d - 1))) >> d));
        }

        in_off += in_batch_size;
        out_off += out_batch_size;
    }

    return ret;
}

// Returns the "pointwise" multiplication a o b.
//
// That is: invNTT(a o b) = invNTT(a) * invNTT(b). Assumes a and b are in
// Montgomery form. Products between coefficients of a and b must be strictly
// bounded in absolute value by 2¹⁵q. a o b will be in Montgomery form and
// bounded in absolute value by 2q.
fn mulHat(a: Poly, b: Poly) Poly {
    // Recall from the discussion in ntt(), that a transformed polynomial is
    // an element of ℤ_q[x]/(x²-ζ) x … x ℤ_q[x]/(x²+ζ¹²⁷);
    // that is: 128 degree-one polynomials instead of simply 256 elements
    // from ℤ_q as in the regular NTT. So instead of pointwise multiplication,
    // we multiply the 128 pairs of degree-one polynomials modulo the
    // right equation:
    //
    //  (a₁ + a₂x)(b₁ + b₂x) = a₁b₁ + a₂b₂ζ' + (a₁b₂ + a₂b₁)x,
    //
    // where ζ' is the appropriate power of ζ.

    // Each iteration handles four coefficients: the pair (i, i+1) is reduced
    // with +ζ' and the pair (i+2, i+3) with -ζ', sharing one zetas entry.
    // k therefore advances once per iteration, starting at 64.
    var p: Poly = undefined;
    var k: usize = 64;
    var i: usize = 0;
    while (i < N) : (i += 4) {
        const z = @as(i32, zetas[k]);
        k += 1;

        const a1b1 = montReduce(@as(i32, a.cs[i + 1]) * @as(i32, b.cs[i + 1]));
        const a0b0 = montReduce(@as(i32, a.cs[i]) * @as(i32, b.cs[i]));
        const a1b0 = montReduce(@as(i32, a.cs[i + 1]) * @as(i32, b.cs[i]));
        const a0b1 = montReduce(@as(i32, a.cs[i]) * @as(i32, b.cs[i + 1]));

        p.cs[i] = montReduce(a1b1 * z) + a0b0;
        p.cs[i + 1] = a0b1 + a1b0;

        const a3b3 = montReduce(@as(i32, a.cs[i + 3]) * @as(i32, b.cs[i + 3]));
        const a2b2 = montReduce(@as(i32, a.cs[i + 2]) * @as(i32, b.cs[i + 2]));
        const a3b2 = montReduce(@as(i32, a.cs[i + 3]) * @as(i32, b.cs[i + 2]));
        const a2b3 = montReduce(@as(i32, a.cs[i + 2]) * @as(i32, b.cs[i + 3]));

        p.cs[i + 2] = a2b2 - montReduce(a3b3 * z);
        p.cs[i + 3] = a2b3 + a3b2;
    }
    return p;
}

// Sample p from a centered binomial distribution with n=2η and p=½ - viz:
// coefficients are in {-η, …, η} with probabilities
//
//  {ncr(0, 2η)/2^2η, ncr(1, 2η)/2^2η, …, ncr(2η,2η)/2^2η}
//
// The random bits are SHAKE-256(seed ‖ nonce), with nonce as a single byte.
fn noise(comptime eta: u8, nonce: u8, seed: *const [32]u8) Poly {
    var h = sha3.Shake256.init(.{});
    const suffix: [1]u8 = .{nonce};
    h.update(seed);
    h.update(&suffix);

    // The distribution at hand is exactly the same as that
    // of (a₁ + a₂ + … + a_η) - (b₁ + … + b_η) where a_i,b_i~U(1).
    // Thus we need 2η bits per coefficient.
    const buf_len = comptime 2 * eta * N / 8;
    var buf: [buf_len]u8 = undefined;
    h.squeeze(&buf);

    // buf is interpreted as a₁…a_ηb₁…b_ηa₁…a_ηb₁…b_η…. We process
    // multiple coefficients in one batch.

    const T = switch (builtin.target.cpu.arch) {
        .x86_64, .x86 => u32, // Generates better code on Intel CPUs
        else => u64, // u128 might be faster on some other CPUs.
    };

    // Comptime setup: batch_count coefficients per word of type T, chosen so it
    // divides N and spans a whole number of bytes; mask has a 1 bit at every
    // η-th position, used to pick out one bit of each η-bit group.
    comptime var batch_count: usize = undefined;
    comptime var batch_bytes: usize = undefined;
    comptime var mask: T = 0;
    comptime {
        batch_count = @bitSizeOf(T) / @as(usize, 2 * eta);
        while (@rem(N, batch_count) != 0 and batch_count > 0) : (batch_count -= 1) {}
        assert(batch_count > 0);
        assert(@rem(2 * eta * batch_count, 8) == 0);
        batch_bytes = 2 * eta * batch_count / 8;

        for (0..2 * eta * batch_count) |_| {
            mask <<= eta;
            mask |= 1;
        }
    }

    var ret: Poly = undefined;
    for (0..comptime N / batch_count) |i| {
        // Read coefficients into t. In the case of η=3,
        // we have t = a₁ + 2a₂ + 4a₃ + 8b₁ + 16b₂ + …
        var t: T = 0;
        inline for (0..batch_bytes) |j| {
            t |= @as(T, buf[batch_bytes * i + j]) << (8 * j);
        }

        // Accumulate `a's and `b's together by masking them out, shifting
        // and adding. For η=3, we have  d = a₁ + a₂ + a₃ + 8(b₁ + b₂ + b₃) + …
        var d: T = 0;
        inline for (0..eta) |j| {
            d += (t >> j) & mask;
        }

        // Extract each a and b separately and set coefficient in polynomial.
        inline for (0..batch_count) |j| {
            const mask2 = comptime (1 << eta) - 1;
            const a = @as(i16, @intCast((d >> (comptime (2 * j * eta))) & mask2));
            const b = @as(i16, @intCast((d >> (comptime ((2 * j + 1) * eta))) & mask2));
            ret.cs[batch_count * i + j] = a - b;
        }
    }

    return ret;
}

// Sample p uniformly from the given seed and x and y coordinates.
fn uniform(seed: [32]u8, x: u8, y: u8) Poly {
    // Rejection sampling over SHAKE-128(seed ‖ x ‖ y): every 3 bytes of output
    // yield two 12-bit candidates; a candidate is kept only when it is < q.
    var xof = sha3.Shake128.init(.{});
    const coords = [2]u8{ x, y };
    xof.update(&seed);
    xof.update(&coords);

    const block_len = sha3.Shake128.block_length; // one SHAKE-128 rate block
    var block: [block_len]u8 = undefined;

    var sampled: Poly = undefined;
    var filled: usize = 0;
    while (filled < N) {
        xof.squeeze(&block);

        var pos: usize = 0;
        while (pos < block_len and filled < N) : (pos += 3) {
            const lo = @as(u16, block[pos]);
            const mid = @as(u16, block[pos + 1]);
            const hi = @as(u16, block[pos + 2]);
            const first = lo | ((mid & 0xf) << 8);
            const second = (mid >> 4) | (hi << 4);

            if (first < Q) {
                sampled.cs[filled] = @as(i16, @intCast(first));
                filled += 1;
            }
            // The second candidate is only consumed while slots remain.
            if (second < Q and filled < N) {
                sampled.cs[filled] = @as(i16, @intCast(second));
                filled += 1;
            }
        }
    }
    return sampled;
}

// Serializes p as 12-bit little-endian coefficients: two coefficients per
// three bytes.
//
// Assumes p is normalized (and not just Barrett reduced).
fn toBytes(p: Poly) [bytes_length]u8 {
    var out: [bytes_length]u8 = undefined;
    var src: usize = 0;
    var dst: usize = 0;
    while (src < N) : ({
        src += 2;
        dst += 3;
    }) {
        const lo: u16 = @intCast(p.cs[src]);
        const hi: u16 = @intCast(p.cs[src + 1]);
        out[dst] = @truncate(lo);
        out[dst + 1] = @truncate((lo >> 8) | (hi << 4));
        out[dst + 2] = @truncate(hi >> 4);
    }
    return out;
}

// Inverse of toBytes(): reads two 12-bit coefficients from every three bytes.
//
// The result is not normalized; each coefficient is only known to be in
// 0 ≤ p[i] < 4096.
fn fromBytes(buf: *const [bytes_length]u8) Poly {
    var out: Poly = undefined;
    var k: usize = 0;
    while (k < N / 2) : (k += 1) {
        const group = buf[3 * k ..][0..3];
        const x0: i16 = group[0];
        const x1: i16 = group[1];
        const x2: i16 = group[2];
        out.cs[2 * k] = x0 | ((x1 & 0xf) << 8);
        out.cs[2 * k + 1] = (x1 >> 4) | (x2 << 4);
    }
    return out;
}
};

// A vector of K polynomials.
// Component-wise operations on Vec delegate to the matching Poly method for
// each of the K entries, in index order.
fn Vec(comptime K: u8) type {
    return struct {
        ps: [K]Poly,

        const Self = @This();
        const bytes_length = K * Poly.bytes_length;

        // Packed size of K polynomials compressed to d bits per coefficient.
        fn compressedSize(comptime d: u8) usize {
            return K * Poly.compressedSize(d);
        }

        // Forward NTT of every component.
        fn ntt(a: Self) Self {
            var out: Self = undefined;
            for (&out.ps, a.ps) |*dst, src| {
                dst.* = src.ntt();
            }
            return out;
        }

        // Inverse NTT of every component.
        fn invNTT(a: Self) Self {
            var out: Self = undefined;
            for (&out.ps, a.ps) |*dst, src| {
                dst.* = src.invNTT();
            }
            return out;
        }

        // Fully normalizes every component to {0, …, q-1}.
        fn normalize(a: Self) Self {
            var out: Self = undefined;
            for (&out.ps, a.ps) |*dst, src| {
                dst.* = src.normalize();
            }
            return out;
        }

        // Barrett reduces every component (not a full normalization).
        fn barrettReduce(a: Self) Self {
            var out: Self = undefined;
            for (&out.ps, a.ps) |*dst, src| {
                dst.* = src.barrettReduce();
            }
            return out;
        }

        // Component-wise a + b.
        fn add(a: Self, b: Self) Self {
            var out: Self = undefined;
            for (&out.ps, a.ps, b.ps) |*dst, x, y| {
                dst.* = x.add(y);
            }
            return out;
        }

        // Component-wise a - b.
        fn sub(a: Self, b: Self) Self {
            var out: Self = undefined;
            for (&out.ps, a.ps, b.ps) |*dst, x, y| {
                dst.* = x.sub(y);
            }
            return out;
        }

        // Samples component i from the centered binomial distribution with
        // parameter η, using the given seed and nonce nonce+i.
        fn noise(comptime eta: u8, nonce: u8, seed: *const [32]u8) Self {
            var out: Self = undefined;
            for (&out.ps, 0..) |*dst, i| {
                dst.* = Poly.noise(eta, nonce + @as(u8, @intCast(i)), seed);
            }
            return out;
        }

        // Inner product Σᵢ a[i] o b[i] using the "pointwise" Poly.mulHat()
        // (see mulHat() and ntt() for the multiplication itself).
        //
        // With a and b in Montgomery form the result is in Montgomery form
        // with coefficients bounded in absolute value by 2kq. Otherwise the
        // result equals the pointwise product multiplied by R⁻¹, the inverse
        // of the Montgomery factor.
        fn dotHat(a: Self, b: Self) Poly {
            var acc: Poly = Poly.zero;
            for (a.ps, b.ps) |x, y| {
                acc = acc.add(x.mulHat(y));
            }
            return acc;
        }

        // Concatenation of each component's Poly.compress(d) output.
        fn compress(v: Self, comptime d: u8) [compressedSize(d)]u8 {
            const chunk = comptime Poly.compressedSize(d);
            var out: [compressedSize(d)]u8 = undefined;
            inline for (0..K) |i| {
                out[i * chunk ..][0..chunk].* = v.ps[i].compress(d);
            }
            return out;
        }

        // Inverse of compress(): splits buf into K equal chunks and
        // decompresses each one.
        fn decompress(comptime d: u8, buf: *const [compressedSize(d)]u8) Self {
            const chunk = comptime Poly.compressedSize(d);
            var out: Self = undefined;
            inline for (0..K) |i| {
                out.ps[i] = Poly.decompress(d, buf[i * chunk ..][0..chunk]);
            }
            return out;
        }

        /// Serializes the vector as the concatenation of each component's
        /// Poly.toBytes() encoding.
        fn toBytes(v: Self) [bytes_length]u8 {
            const stride = Poly.bytes_length;
            var out: [bytes_length]u8 = undefined;
            inline for (0..K) |i| {
                out[i * stride ..][0..stride].* = v.ps[i].toBytes();
            }
            return out;
        }

        /// Deserializes a vector from K consecutive Poly.toBytes() encodings.
        fn fromBytes(buf: *const [bytes_length]u8) Self {
            const stride = Poly.bytes_length;
            var out: Self = undefined;
            inline for (0..K) |i| {
                out.ps[i] = Poly.fromBytes(buf[i * stride ..][0..stride]);
            }
            return out;
        }
    };
}

// A matrix of K vectors: entry (i, j) is vs[i].ps[j].
fn Mat(comptime K: u8) type {
    return struct {
        const Self = @This();
        vs: [K]Vec(K),

        // Expands seed into a matrix. Entry (i, j) is Poly.uniform(seed, j, i),
        // or Poly.uniform(seed, i, j) when transposed is set.
        fn uniform(seed: [32]u8, comptime transposed: bool) Self {
            var out: Self = undefined;
            for (0..K) |row| {
                const r: u8 = @intCast(row);
                for (0..K) |col| {
                    const c: u8 = @intCast(col);
                    out.vs[row].ps[col] = if (transposed)
                        Poly.uniform(seed, r, c)
                    else
                        Poly.uniform(seed, c, r);
                }
            }
            return out;
        }

        // Returns the transpose of m.
        fn transpose(m: Self) Self {
            var out: Self = undefined;
            for (&out.vs, 0..) |*row, i| {
                for (&row.ps, 0..) |*entry, j| {
                    entry.* = m.vs[j].ps[i];
                }
            }
            return out;
        }
    };
}

// Constant-time comparison: returns 1 when a ≠ b and 0 when a = b.
fn ctneq(comptime len: usize, a: [len]u8, b: [len]u8) u1 {
    const equal = crypto.timing_safe.eql([len]u8, a, b);
    return @intFromBool(!equal);
}

// Copy src into dst given b = 1.
// Copies src into dst when b = 1 and leaves dst unchanged when b = 0.
// The choice is made with a mask (all ones or all zeros) instead of a branch.
fn cmov(comptime len: usize, dst: *[len]u8, src: [len]u8, b: u1) void {
    const mask = @as(u8, 0) -% b;
    for (0..len) |i| {
        dst[i] ^= mask & (dst[i] ^ src[i]);
    }
}

test "MulHat" {
    var rnd = RndGen.init(0);

    for (0..100) |_| {
        const a = Poly.randAbsLeqQ(&rnd);
        const b = Poly.randAbsLeqQ(&rnd);

        const p2 = a.ntt().mulHat(b.ntt()).barrettReduce().invNTT().normalize();
        var p: Poly = undefined;

        @memset(&p.cs, 0);

        // Schoolbook multiplication in ℤ_q[X]/(Xᴺ+1) as the reference result.
        for (0..N) |i| {
            for (0..N) |j| {
                var v = montReduce(@as(i32, a.cs[i]) * @as(i32, b.cs[j]));
                var k = i + j;
                if (k >= N) {
                    // Recall Xᴺ = -1.
                    k -= N;
                    v = -v;
                }
                p.cs[k] = feBarrettReduce(v + p.cs[k]);
            }
        }

        p = p.toMont().normalize();

        try testing.expectEqual(p, p2);
    }
}

test "NTT" {
    var rnd = RndGen.init(0);

    for (0..1000) |_| {
        var p = Poly.randAbsLeqQ(&rnd);
        const q = p.toMont().normalize();
        p = p.ntt();

        for (0..N) |i| {
            try testing.expect(p.cs[i] <= 7 * Q and -7 * Q <= p.cs[i]);
        }

        p = p.normalize().invNTT();
        for (0..N) |i| {
            try testing.expect(p.cs[i] <= Q and -Q <= p.cs[i]);
        }

        p = p.normalize();

        try testing.expectEqual(p, q);
    }
}

test "Compression" {
    var rnd = RndGen.init(0);
    inline for (.{ 1, 4, 5, 10, 11 }) |d| {
        for (0..1000) |_| {
            const p = Poly.randNormalized(&rnd);
            const pp = p.compress(d);
            const pq = Poly.decompress(d, &pp).compress(d);
            try testing.expectEqual(pp, pq);
        }
    }
}

test "noise" {
    var seed: [32]u8 = undefined;
    for (&seed, 0..) |*s, i| {
        s.* = @as(u8, @intCast(i));
    }
    try testing.expectEqual(Poly.noise(3, 37, &seed).cs, .{
        0, 0, 1, -1, 0, 2, 0, -1, -1, 3, 0, 1, -2, -2, 0, 1, -2, 1, 0, -2, 3, 0, 0, 0, 1, 3, 1, 1, 2, 1, -1, -1, -1, 0, 1, 0, 1, 0, 2, 0, 1, -2, 0, -1, -1, -2, 1, -1, -1, 2, -1, 1, 1, 2, -3, -1, -1, 0, 0, 0, 0, 1, -1, -2, -2, 0, -2, 0, 0, 0, 1, 0, -1, -1, 1, -2, 2, 0, 0, 2, -2, 0, 1, 0, 1, 1, 1, 0, 1, -2, -1, -2, -1, 1, 0, 0, 0, 0, 0, 1, 0, -1, -1, 0, -1, 1, 0, 1, 0, -1, -1, 0, -2, 2, 0, -2, 1, -1, 0, 1, -1, -1, 2, 1, 0, 0, -2, -1, 2, 0, 0, 0, -1, -1, 3, 1, 0, 1, 0, 1, 0, 2, 1, 0, 0, 1, 0, 1, 0, 0, -1, -1, -1, 0, 1, 3, 1, 0, 1, 0, 1, -1, -1, -1, -1, 0, 0, -2, -1, -1, 2, 0, 1, 0, 1, 0, 2, -2, 0, 1, 1, -3, -1, -2, -1, 0, 1, 0, 1, -2, 2, 2, 1, 1, 0, -1, 0, -1, -1, 1, 0, -1, 2, 1, -1, 1, 2, -2, 1, 2, 0, 1, 2, 1, 0, 0, 2, 1, 2, 1, 0, 2, 1, 0, 0, -1, -1, 1, -1, 0, 1, -1, 2, 2, 0, 0, -1, 1, 1, 1, 1, 0, 0, -2, 0, -1, 1, 2, 0, 0, 1, 1, -1, 1, 0, 1,
    });
    try testing.expectEqual(Poly.noise(2, 37, &seed).cs, .{
        1, 0, 1, -1, -1, -2, -1, -1, 2, 0, -1, 0, 0, -1, 1, 1, -1, 1, 0, 2, -2, 0, 1, 2, 0, 0, -1, 1, 0, -1, 1, -1, 1, 2, 1, 1, 0, -1, 1, -1, -2, -1, 1, -1, -1, -1, 2, -1, -1, 0, 0, 1, 1, -1, 1, 1, 1, 1, -1, -2, 0, 1, 0, 0, 2, 1, -1, 2, 0, 0, 1, 1, 0, -1, 0, 0, -1, -1, 2, 0, 1, -1, 2, -1, -1, -1, -1, 0, -2, 0, 2, 1, 0, 0, 0, -1, 0, 0, 0, -1, -1, 0, -1, -1, 0, -1, 0, 0, -2, 1, 1, 0, 1, 0, 1, 0, 1, 1, -1, 2, 0, 1, -1, 1, 2, 0, 0, 0, 0, -1, -1, -1, 0, 1, 0, -1, 2, 0, 0, 1, 1, 1, 0, 1, -1, 1, 2, 1, 0, 2, -1, 1, -1, -2, -1, -2, -1, 1, 0, -2, -2, -1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 2, 2, 0, 1, 0, -1, -1, 0, 2, 0, 0, -2, 1, 0, 2, 1, -1, -2, 0, 0, -1, 1, 1, 0, 0, 2, 0, 1, 1, -2, 1, -2, 1, 1, 0, 2, 0, -1, 0, -1, 0, 1, 2, 0, 1, 0, -2, 1, -2, -2, 1, -1, 0, -1, 1, 1, 0, 0, 0, 1, 0, -1, 1, 1, 0, 0, 0, 0, 1, 0, 1, -1, 0, 1, -1, -1, 2, 0, 0, 1, -1, 0, 1, -1, 0,
    });
}

test "uniform sampling" {
    var seed: [32]u8 = undefined;
    for (&seed, 0..) |*s, i| {
        s.* = @as(u8, @intCast(i));
    }
    try testing.expectEqual(Poly.uniform(seed, 1, 0).cs, .{
        797, 993, 161, 6, 2608, 2385, 2096, 2661, 1676, 247, 2440, 342, 634, 194, 1570, 2848, 986, 684, 3148, 3208, 2018, 351, 2288, 612, 1394, 170, 1521, 3119, 58, 596, 2093, 1549, 409, 2156, 1934, 1730, 1324, 388, 446, 418, 1719, 2202, 1812, 98, 1019, 2369, 214, 2699, 28, 1523, 2824, 273, 402, 2899, 246, 210, 1288, 863, 2708, 177, 3076, 349, 44, 949, 854, 1371, 957, 292, 2502, 1617, 1501, 254, 7, 1761, 2581, 2206, 2655, 1211, 629, 1274, 2358, 816, 2766, 2115, 2985, 1006, 2433, 856, 2596, 3192, 1, 1378, 2345, 707, 1891, 1669, 536, 1221, 710, 2511, 120, 1176, 322, 1897, 2309, 595, 2950, 1171, 801, 1848, 695, 2912, 1396, 1931, 1775, 2904, 893, 2507, 1810, 2873, 253, 1529, 1047, 2615, 1687, 831, 1414, 965, 3169, 1887, 753, 3246, 1937, 115, 2953, 586, 545, 1621, 1667, 3187, 1654, 1988, 1857, 512, 1239, 1219, 898, 3106, 391, 1331, 2228, 3169, 586, 2412, 845, 768, 156, 662, 478, 1693, 2632, 573, 2434, 1671, 173, 969, 364, 1663, 2701, 2169, 813, 1000, 1471, 720, 2431, 2530, 3161, 733, 1691, 527, 2634, 335, 26, 2377, 1707, 767, 3020, 950, 502, 426, 1138, 3208, 2607, 2389, 44, 1358, 1392, 2334, 875, 2097, 173, 1697, 2578, 942, 1817, 974, 1165, 2853, 1958, 2973, 3282, 271, 1236, 1677, 2230, 673, 1554, 96, 242, 1729, 2518, 1884, 2272, 71, 1382, 924, 1807, 1610, 456, 1148, 2479, 2152, 238, 2208, 2329, 713, 1175, 1196, 757, 1078, 3190, 3169, 708, 3117, 154, 1751, 3225, 1364, 154, 23, 2842, 1105, 1419, 79, 5, 2013,
    });
}

test "Polynomial packing" {
    var rnd = RndGen.init(0);

    for (0..1000) |_| {
        const p = Poly.randNormalized(&rnd);
        try testing.expectEqual(Poly.fromBytes(&p.toBytes()), p);
    }
}

test "Test inner PKE" {
    var seed: [32]u8 = undefined;
    var pt: [32]u8 = undefined;
    for (&seed, &pt, 0..) |*s, *p, i| {
        s.* = @as(u8, @intCast(i));
        p.* = @as(u8, @intCast(i + 32));
    }
    inline for (modes) |mode| {
        for (0..10) |i| {
            var pk: mode.InnerPk = undefined;
            var sk: mode.InnerSk = undefined;
            seed[0] = @as(u8, @intCast(i));
            mode.innerKeyFromSeed(seed, &pk, &sk);
            for (0..10) |j| {
                seed[1] = @as(u8, @intCast(j));
                try testing.expectEqual(sk.decrypt(&pk.encrypt(&pt, &seed)), pt);
            }
        }
    }
}

test "Test happy flow" {
    var seed: [64]u8 = undefined;
    for (&seed, 0..) |*s, i| {
        s.* = @as(u8, @intCast(i));
    }
    inline for (modes) |mode| {
        for (0..10) |i| {
            seed[0] = @as(u8, @intCast(i));
            const kp = try mode.KeyPair.generateDeterministic(seed);
            const sk = try mode.SecretKey.fromBytes(&kp.secret_key.toBytes());
            try testing.expectEqual(sk, kp.secret_key);
            const pk = try mode.PublicKey.fromBytes(&kp.public_key.toBytes());
            try testing.expectEqual(pk, kp.public_key);
            for (0..10) |j| {
                seed[1] = @as(u8, @intCast(j));
                const e = pk.encaps(seed[0..32].*);
                try testing.expectEqual(e.shared_secret, try sk.decaps(&e.ciphertext));
            }
        }
    }
}

// Code to test NIST Known Answer Tests (KAT), see PQCgenKAT.c.

const sha2 = crypto.hash.sha2;

test "NIST KAT test" {
    inline for (.{
        .{ d00.Kyber512, "e9c2bd37133fcb40772f81559f14b1f58dccd1c816701be9ba6214d43baf4547" },
        .{ d00.Kyber1024, "89248f2f33f7f4f7051729111f3049c409a933ec904aedadf035f30fa5646cd5" },
        .{ d00.Kyber768, "a1e122cad3c24bc51622e4c242d8b8acbcd3f618fee4220400605ca8f9ea02c2" },
    }) |modeHash| {
        const mode = modeHash[0];
        var seed: [48]u8 = undefined;
        for (&seed, 0..) |*s, i| {
            s.* = @as(u8, @intCast(i));
        }
        // The whole KAT transcript is hashed and compared to the expected digest.
        var f = sha2.Sha256.init(.{});
        const fw = f.writer();
        var g = NistDRBG.init(seed);
        try std.fmt.format(fw, "# {s}\n\n", .{mode.name});
        for (0..100) |i| {
            g.fill(&seed);
            try std.fmt.format(fw, "count = {}\n", .{i});
            try std.fmt.format(fw, "seed = {X}\n", .{&seed});
            var g2 = NistDRBG.init(seed);

            // This is not equivalent to g2.fill(kseed[:]). As the reference
            // implementation calls randombytes twice generating the keypair,
            // we have to do that as well.
            var kseed: [64]u8 = undefined;
            var eseed: [32]u8 = undefined;
            g2.fill(kseed[0..32]);
            g2.fill(kseed[32..64]);
            g2.fill(&eseed);
            const kp = try mode.KeyPair.generateDeterministic(kseed);
            const e = kp.public_key.encaps(eseed);
            const ss2 = try kp.secret_key.decaps(&e.ciphertext);
            try testing.expectEqual(ss2, e.shared_secret);
            try std.fmt.format(fw, "pk = {X}\n", .{&kp.public_key.toBytes()});
            try std.fmt.format(fw, "sk = {X}\n", .{&kp.secret_key.toBytes()});
            try std.fmt.format(fw, "ct = {X}\n", .{&e.ciphertext});
            try std.fmt.format(fw, "ss = {X}\n\n", .{&e.shared_secret});
        }

        var out: [32]u8 = undefined;
        f.final(&out);
        var outHex: [64]u8 = undefined;
        _ = try std.fmt.bufPrint(&outHex, "{x}", .{&out});
        try testing.expectEqual(outHex, modeHash[1].*);
    }
}

// Port of the AES-256 CTR DRBG used by the NIST reference KAT generator
// (randombytes), so the test reproduces the reference seeds exactly.
const NistDRBG = struct {
    key: [32]u8,
    v: [16]u8,

    // Increments the 128-bit big-endian counter v by one; on overflow it wraps
    // to all zeros.
    fn incV(g: *NistDRBG) void {
        // Walk from the last (least significant) byte towards byte 0,
        // propagating the carry. The index is unsigned, so the loop guard must
        // be `j > 0` with a pre-decrement: a `j >= 0` test is always true and
        // `j -= 1` would underflow once a carry passes byte 0.
        var j: usize = g.v.len;
        while (j > 0) {
            j -= 1;
            if (g.v[j] == 255) {
                g.v[j] = 0;
            } else {
                g.v[j] += 1;
                break;
            }
        }
    }

    // AES256_CTR_DRBG_Update(pd, &g.key, &g.v).
    // Generates 48 bytes of AES-256-CTR keystream, optionally XORs in the
    // provided data pd, and uses the result as the new key (32 bytes) and
    // counter (16 bytes).
    fn update(g: *NistDRBG, pd: ?[48]u8) void {
        var buf: [48]u8 = undefined;
        const ctx = crypto.core.aes.Aes256.initEnc(g.key);
        var i: usize = 0;
        while (i < 3) : (i += 1) {
            g.incV();
            var block: [16]u8 = undefined;
            ctx.encrypt(&block, &g.v);
            buf[i * 16 ..][0..16].* = block;
        }
        if (pd) |p| {
            for (&buf, p) |*b, x| {
                b.* ^= x;
            }
        }
        g.key = buf[0..32].*;
        g.v = buf[32..48].*;
    }

    // randombytes: fills out with AES-256-CTR keystream, then refreshes the
    // state with update(null).
    fn fill(g: *NistDRBG, out: []u8) void {
        var block: [16]u8 = undefined;
        var dst = out;

        const ctx = crypto.core.aes.Aes256.initEnc(g.key);
        while (dst.len > 0) {
            g.incV();
            ctx.encrypt(&block, &g.v);
            if (dst.len < 16) {
                @memcpy(dst, block[0..dst.len]);
                break;
            }
            dst[0..block.len].* = block;
            dst = dst[block.len..];
        }
        g.update(null);
    }

    // Instantiates the DRBG from an all-zero state mixed with the 48-byte seed.
    fn init(seed: [48]u8) NistDRBG {
        var ret: NistDRBG = .{ .key = .{0} ** 32, .v = .{0} ** 16 };
        ret.update(seed);
        return ret;
    }
};
Generated by zstd-live on 2025-08-10 02:45:58 UTC. |