Program Listing for File chacha.hpp
↰ Return to documentation for file (include/random/chacha.hpp)
#pragma once
#include <array>
#include <bit>
#include <cstdint>
#include <limits>
#include "macros.hpp"
namespace prng {
template<std::uint8_t R = 20>
class ChaCha {
protected:
static constexpr auto MATRIX_WORDCOUNT = std::uint8_t{16};
static constexpr auto KEY_WORDCOUNT = std::uint8_t{8};
public:
using result_type = std::uint64_t;
using input_word = std::uint64_t;
using matrix_word = std::uint32_t;
using matrix_type = std::array<matrix_word, MATRIX_WORDCOUNT>;
using result_cache_type = std::array<result_type, MATRIX_WORDCOUNT / 2>;
static constexpr PRNG_ALWAYS_INLINE auto(min)() noexcept {
return (std::numeric_limits<result_type>::min)();
}
static constexpr PRNG_ALWAYS_INLINE auto(max)() noexcept {
return (std::numeric_limits<result_type>::max)();
}
PRNG_ALWAYS_INLINE explicit ChaCha(
const std::array<matrix_word, KEY_WORDCOUNT> key,
const input_word counter,
const input_word nonce
) noexcept {
// First four words (i.e. top-row) are always the same constants
// They spell out "expand 2-byte k" in ASCII (little-endian)
m_state[0] = 0x61707865;
m_state[1] = 0x3320646e;
m_state[2] = 0x79622d32;
m_state[3] = 0x6b206574;
for (auto i = 0; i < 8; ++i) {
m_state[4 + i] = key[i];
}
// ChaCha assumes little-endianness
m_state[12] = static_cast<matrix_word>(counter & 0xFFFFFFFF);
m_state[13] = static_cast<matrix_word>(counter >> 32);
m_state[14] = static_cast<matrix_word>(nonce & 0xFFFFFFFF);
m_state[15] = static_cast<matrix_word>(nonce >> 32);
}
PRNG_ALWAYS_INLINE constexpr result_type(operator())() noexcept { return next_result(); }
PRNG_ALWAYS_INLINE constexpr double uniform() noexcept {
return static_cast<double>(operator()() >> 11) * 0x1.0p-53;
}
PRNG_ALWAYS_INLINE constexpr matrix_type block() noexcept {
if (m_result_index < m_result_cache.size()) {
auto cached_block = results_to_block(m_result_cache);
m_result_index = static_cast<std::uint8_t>(m_result_cache.size());
return cached_block;
}
return next_block();
}
PRNG_ALWAYS_INLINE constexpr matrix_type getState() const noexcept {
matrix_type state = m_state;
if (m_result_index < m_result_cache.size()) {
const input_word counter =
(static_cast<input_word>(state[13]) << 32) | static_cast<input_word>(state[12]);
const input_word current_counter = counter - 1;
state[12] = static_cast<matrix_word>(current_counter & 0xFFFFFFFF);
state[13] = static_cast<matrix_word>(current_counter >> 32);
}
return state;
}
private:
matrix_type m_state;
result_cache_type m_result_cache{};
std::uint8_t m_result_index = static_cast<std::uint8_t>(m_result_cache.size());
static constexpr PRNG_ALWAYS_INLINE auto rotl(const matrix_word x, const int k) noexcept {
return std::rotl(x, k);
}
static constexpr PRNG_ALWAYS_INLINE void quarter_round(
matrix_type &m,
const unsigned int a,
const unsigned int b,
const unsigned int c,
const unsigned int d
) noexcept {
m[a] += m[b]; m[d] ^= m[a]; m[d] = rotl(m[d], 16);
m[c] += m[d]; m[b] ^= m[c]; m[b] = rotl(m[b], 12);
m[a] += m[b]; m[d] ^= m[a]; m[d] = rotl(m[d], 8);
m[c] += m[d]; m[b] ^= m[c]; m[b] = rotl(m[b], 7);
}
constexpr PRNG_ALWAYS_INLINE void inc_counter() noexcept {
if (++m_state[12] == 0) {
++m_state[13];
}
}
static constexpr PRNG_ALWAYS_INLINE result_cache_type block_to_results(const matrix_type& block) noexcept {
return std::bit_cast<result_cache_type>(block);
}
static constexpr PRNG_ALWAYS_INLINE matrix_type results_to_block(const result_cache_type& results) noexcept {
return std::bit_cast<matrix_type>(results);
}
constexpr PRNG_ALWAYS_INLINE result_type next_result() noexcept {
if (m_result_index >= m_result_cache.size()) [[unlikely]] {
m_result_cache = block_to_results(next_block());
m_result_index = 0;
}
return m_result_cache[m_result_index++];
}
PRNG_FLATTEN constexpr PRNG_ALWAYS_INLINE matrix_type next_block() noexcept {
matrix_type x = m_state;
// Note that we perform both an odd and even round at the same time.
// As a result the amount of rounds performed is always rounded up to an even number.
for (auto i = 0; i < R; i += 2) {
// Odd round
quarter_round(x, 0, 4, 8,12);
quarter_round(x, 1, 5, 9,13);
quarter_round(x, 2, 6,10,14);
quarter_round(x, 3, 7,11,15);
// Even round
quarter_round(x, 0, 5,10,15);
quarter_round(x, 1, 6,11,12);
quarter_round(x, 2, 7, 8,13);
quarter_round(x, 3, 4, 9,14);
}
for (auto i = 0; i < MATRIX_WORDCOUNT; ++i) {
x[i] += m_state[i];
}
inc_counter();
return x;
}
};
}