Program Listing for File chacha_simd.hpp

Return to documentation for file (include/random/chacha_simd.hpp)

#pragma once

#include <array>
#include <bit>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <new>
#include <type_traits>

#include <poet/poet.hpp>

#include "dispatch_arch.hpp"
#include "macros.hpp"

namespace prng {

namespace internal {

// SIMD ChaCha block generator bound to a single xsimd architecture.
//
// The object stores the 16-word ChaCha input matrix together with a cache of
// already-generated output blocks. One call to gen_block_batch() computes
// SIMD_WIDTH blocks at once (one block per SIMD lane, consecutive counters);
// next_block() then hands the cached blocks out one at a time.
//
// Template parameters:
//   Arch - xsimd architecture tag used to pick the batch type.
//   R    - ChaCha round count; R / 2 double rounds (one column round plus one
//          diagonal round) are run per block.
template <class Arch, std::uint8_t R> struct ChaChaState {
  // Word counts of the full ChaCha matrix and of the key portion of it.
  static constexpr auto MATRIX_WORDCOUNT = std::uint8_t{16};
  static constexpr auto KEY_WORDCOUNT = std::uint8_t{8};

  // input_word is the type of the 64-bit counter / nonce constructor
  // arguments; matrix_word is one 32-bit cell of the ChaCha matrix.
  using input_word = std::uint64_t;
  using matrix_word = std::uint32_t;
  using matrix_type = std::array<matrix_word, MATRIX_WORDCOUNT>;
  // During generation each batch holds one matrix word for SIMD_WIDTH
  // independent blocks (one block per lane).
  using simd_type = xsimd::batch<matrix_word, Arch>;
  using working_state_type = std::array<simd_type, MATRIX_WORDCOUNT>;

  // Number of 32-bit lanes per batch, i.e. blocks computed in parallel.
  static constexpr std::uint8_t SIMD_WIDTH = std::uint8_t{simd_type::size};
  // Guard: SIMD_WIDTH may be 0 when an arch is instantiated in the byte storage
  // but not available at compile time. Operations are only invoked from dispatch
  // TUs compiled with the correct -march flags.
  static_assert(SIMD_WIDTH == 0 || std::has_single_bit(static_cast<unsigned int>(SIMD_WIDTH)),
                "ChaCha SIMD width must be a power of two");
  // log2(SIMD_WIDTH); next_block() shifts a cache index by it to get the batch.
  static constexpr std::uint8_t SIMD_WIDTH_SHIFT =
      SIMD_WIDTH > 0 ? static_cast<std::uint8_t>(std::countr_zero(static_cast<unsigned int>(SIMD_WIDTH))) : 0;
  static_assert(SIMD_WIDTH == 0 || MATRIX_WORDCOUNT % SIMD_WIDTH == 0,
                "ChaCha state must divide evenly into SIMD segments");
  // SIMD_WIDTH - 1; next_block() masks a cache index with it to get the lane.
  static constexpr std::uint8_t SIMD_WIDTH_MASK = SIMD_WIDTH > 0 ? std::uint8_t(SIMD_WIDTH - 1) : 0;
  // Number of batches needed to store the 16 words of one finished block.
  static constexpr std::uint8_t BLOCK_SEGMENTCOUNT =
      SIMD_WIDTH > 0 ? static_cast<std::uint8_t>(MATRIX_WORDCOUNT / SIMD_WIDTH) : 0;
  // Number of gen_block_batch() calls performed per cache refill.
  static constexpr std::uint8_t cache_batchcount() noexcept {
    // Use 2 batches for 512-bit+ SIMD (fewer but wider blocks per batch).
    if constexpr (simd_type::size >= 16) {
      return 2;
    } else {
      return 1;
    }
  }

  static constexpr auto CACHE_BATCHCOUNT = cache_batchcount();
  // Total number of blocks available after a cache refill.
  static constexpr auto CACHE_BLOCKCOUNT = std::uint8_t{CACHE_BATCHCOUNT * SIMD_WIDTH};
  // One finished block, stored as BLOCK_SEGMENTCOUNT batches.
  using cache_block_type = std::array<simd_type, BLOCK_SEGMENTCOUNT>;
  // The SIMD_WIDTH blocks produced by one gen_block_batch() call, indexed by lane.
  using cache_batch_type = std::array<cache_block_type, SIMD_WIDTH>;
  // next_block() bit_casts a cache_block_type to matrix_type, hence the size
  // and trivial-copyability requirements below.
  // NOTE(review): unlike the SIMD_WIDTH asserts above, this size check has no
  // SIMD_WIDTH == 0 escape hatch - confirm zero-width instantiations are never
  // actually compiled.
  static_assert(sizeof(cache_block_type) == sizeof(matrix_type),
                "Cache blocks must have the same layout size as a ChaCha block");
  static_assert(std::is_trivially_copyable_v<cache_block_type>, "Cache blocks must be trivially copyable for bit_cast");
  static_assert(std::is_trivially_copyable_v<matrix_type>, "ChaCha blocks must be trivially copyable for bit_cast");

  // Input matrix. Words 12/13 hold the low/high halves of the 64-bit counter of
  // the first block that has not been generated into the cache yet.
  matrix_type m_state;
  // Pre-generated output blocks, addressed as m_cache[batch][lane].
  alignas(simd_type::arch_type::alignment()) std::array<cache_batch_type, CACHE_BATCHCOUNT> m_cache;
  // Index of the next cached block to hand out; CACHE_BLOCKCOUNT means the
  // cache is exhausted (this is also the initial value).
  std::uint8_t m_cache_index = CACHE_BLOCKCOUNT;

  // Builds the input matrix from a 256-bit key, a 64-bit block counter and a
  // 64-bit nonce. The block cache is left empty (and uninitialised).
  explicit PRNG_ALWAYS_INLINE ChaChaState(const std::array<matrix_word, KEY_WORDCOUNT> key, const input_word counter,
                                          const input_word nonce) {
    // Words 0-3: fixed constants; read as little-endian ASCII they spell
    // "expand 32-byte k".
    m_state[0] = 0x61707865;
    m_state[1] = 0x3320646e;
    m_state[2] = 0x79622d32;
    m_state[3] = 0x6b206574;

    // Words 4-11: the key.
    poet::static_for<0, KEY_WORDCOUNT>([&](auto I) {
      m_state[4 + I] = key[I];
    });

    // Words 12-13: counter (low word first). Words 14-15: nonce (low word first).
    m_state[12] = static_cast<matrix_word>(counter & 0xFFFFFFFF);
    m_state[13] = static_cast<matrix_word>(counter >> 32);
    m_state[14] = static_cast<matrix_word>(nonce & 0xFFFFFFFF);
    m_state[15] = static_cast<matrix_word>(nonce >> 32);
  }

  // Returns a copy of the input matrix whose counter identifies the block that
  // next_block() would return next; with `prev` set, the counter of the block
  // before that one is reported instead.
  //
  // m_state's own counter has already been advanced past every cached block, so
  // the number of cached-but-unconsumed blocks is subtracted here. The
  // subtraction is done on the 64-bit counter and wraps modulo 2^64.
  PRNG_ALWAYS_INLINE matrix_type getState(bool prev) const noexcept {
    matrix_type state = m_state;
    if (m_cache_index < CACHE_BLOCKCOUNT || prev) {
      input_word counter = (static_cast<input_word>(state[13]) << 32) | static_cast<input_word>(state[12]);
      counter -= static_cast<input_word>(CACHE_BLOCKCOUNT - m_cache_index);
      if (prev) {
        --counter;
      }
      state[12] = static_cast<matrix_word>(counter & 0xFFFFFFFF);
      state[13] = static_cast<matrix_word>(counter >> 32);
    }
    return state;
  }

  // Returns the next 16-word output block, refilling the cache first when all
  // cached blocks have been consumed.
  PRNG_ALWAYS_INLINE matrix_type next_block() noexcept {
    if (m_cache_index >= CACHE_BLOCKCOUNT) [[unlikely]] {
      gen_next_blocks_in_cache();
      m_cache_index = 0;
    }

    // Split the linear cache index into (batch, lane).
    const auto cache_batch = m_cache_index >> SIMD_WIDTH_SHIFT;
    const auto lane = m_cache_index & SIMD_WIDTH_MASK;
    ++m_cache_index;
    return std::bit_cast<matrix_type>(m_cache[cache_batch][lane]);
  }

private:
  // {0, 1, ..., SIMD_WIDTH - 1}: per-lane offset added to the low counter word.
  static inline constexpr std::array<matrix_word, SIMD_WIDTH> LANE_OFFSETS = [] {
    std::array<matrix_word, SIMD_WIDTH> offsets{};
    poet::static_for<0, SIMD_WIDTH>([&](auto I) {
      offsets[I] = static_cast<matrix_word>(I.value);
    });
    return offsets;
  }();

  // Builds the per-lane carry into the high counter word.
  //
  // `overflow_index` is how far the low counter word is from its maximum
  // (max - low word). A lane whose offset exceeds it wraps the low word and
  // therefore needs +1 on the high word; all other lanes get 0.
  PRNG_ALWAYS_INLINE static simd_type make_higher_counter_inc(matrix_word overflow_index) noexcept {
    // Common case: no lane offset can exceed overflow_index, so nothing wraps.
    if (overflow_index >= SIMD_WIDTH) [[likely]] {
      return simd_type::broadcast(0);
    }

    // Lane 0 has offset 0 and can never wrap, so the loop starts at lane 1.
    std::array<matrix_word, SIMD_WIDTH> incs{};
    poet::static_for<1, SIMD_WIDTH>([&](auto I) {
      incs[I] = static_cast<matrix_word>(overflow_index < static_cast<matrix_word>(I.value));
    });
    return simd_type::load_unaligned(incs.data());
  }

  // Fills the working state: every matrix word is broadcast to all lanes, then
  // the per-lane counter increments are applied to words 12 (low) and 13 (high).
  PRNG_ALWAYS_INLINE static void init_state_batches(working_state_type &x, const matrix_type &state,
                                                     simd_type lower_counter_inc,
                                                     simd_type higher_counter_inc) noexcept {
    poet::static_for<0, MATRIX_WORDCOUNT>([&](auto I) {
      x[I] = simd_type::broadcast(state[I]);
    });
    x[12] += lower_counter_inc;
    x[13] += higher_counter_inc;
  }

  // Final feed-forward step: adds the per-lane input matrix (base state plus
  // the same counter increments used by init_state_batches) to the permuted
  // working state.
  PRNG_ALWAYS_INLINE static void add_original_state(working_state_type &x, const matrix_type &state,
                                                     simd_type lower_counter_inc,
                                                     simd_type higher_counter_inc) noexcept {
    poet::static_for<0, MATRIX_WORDCOUNT>([&](auto I) {
      x[I] += simd_type::broadcast(state[I]);
    });
    x[12] += lower_counter_inc;
    x[13] += higher_counter_inc;
  }

  // Converts the word-major working state (x[word] holds that word for every
  // lane) into per-lane blocks in `cache`.
  //
  // The working state is processed in segments of SIMD_WIDTH rows. Each segment
  // is transposed in place with xsimd::transpose (so `x` is modified), after
  // which row `Lane` of the segment is stored as segment `Seg` of that lane's
  // cached block.
  static void transpose_into_cache(cache_batch_type &cache, working_state_type &x) noexcept {
    auto *PRNG_RESTRICT cache_lanes = cache.data();
    auto *PRNG_RESTRICT working = x.data();
    poet::static_for<0, BLOCK_SEGMENTCOUNT>([&](auto Seg) {
      auto *PRNG_RESTRICT segment_begin = working + Seg * SIMD_WIDTH;
      xsimd::transpose(segment_begin, segment_begin + SIMD_WIDTH);
      poet::static_for<0, SIMD_WIDTH>([&](auto Lane) {
        cache_lanes[Lane][Seg] = segment_begin[Lane];
      });
    });
  }

  // Advances the 64-bit counter in words 12/13 by SIMD_WIDTH (one batch of
  // blocks). After the 32-bit addition wraps, the low word is smaller than
  // SIMD_WIDTH, which is used as the carry into the high word.
  PRNG_ALWAYS_INLINE static constexpr void advance_counter(matrix_type &state) noexcept {
    state[12] += SIMD_WIDTH;
    state[13] += state[12] < SIMD_WIDTH;
  }

  // ChaCha quarter round on matrix words A, B, C, D, applied to all lanes at
  // once: four add / xor / rotate-left steps with rotations 16, 12, 8 and 7.
  template <unsigned A, unsigned B, unsigned C, unsigned D>
  PRNG_ALWAYS_INLINE static void quarter_round(working_state_type &x) noexcept {
    x[A] += x[B]; x[D] ^= x[A]; x[D] = xsimd::rotl<16>(x[D]);
    x[C] += x[D]; x[B] ^= x[C]; x[B] = xsimd::rotl<12>(x[B]);
    x[A] += x[B]; x[D] ^= x[A]; x[D] = xsimd::rotl<8>(x[D]);
    x[C] += x[D]; x[B] ^= x[C]; x[B] = xsimd::rotl<7>(x[B]);
  }

  // Generates SIMD_WIDTH consecutive blocks into `cache`. Lane i uses the
  // counter of `state` plus i; `state` itself is not modified.
  PRNG_ALWAYS_INLINE static void gen_block_batch(cache_batch_type &cache, const matrix_type &state) noexcept {
    const simd_type lower_counter_inc = simd_type::load_unaligned(LANE_OFFSETS.data());
    // Distance of the low counter word from its maximum; decides which lanes
    // carry into the high counter word.
    matrix_word overflow_index = std::numeric_limits<matrix_word>::max() - state[12];
    const simd_type higher_counter_inc = make_higher_counter_inc(overflow_index);

    working_state_type x;
    init_state_batches(x, state, lower_counter_inc, higher_counter_inc);

    // R rounds, executed as R / 2 double rounds.
    poet::static_for<0, R / 2>([&](auto) {
      // Column round: QR(i, i+4, i+8, i+12)
      poet::static_for<0, 4>([&](auto I) {
        constexpr auto i = static_cast<unsigned>(I.value);
        quarter_round<i, i + 4, i + 8, i + 12>(x);
      });
      // Diagonal round: QR(i, ((i+1)%4)+4, ((i+2)%4)+8, ((i+3)%4)+12)
      poet::static_for<0, 4>([&](auto I) {
        constexpr auto i = static_cast<unsigned>(I.value);
        quarter_round<i, ((i + 1) % 4) + 4, ((i + 2) % 4) + 8, ((i + 3) % 4) + 12>(x);
      });
    });

    add_original_state(x, state, lower_counter_inc, higher_counter_inc);
    transpose_into_cache(cache, x);
  }

  // Refills the whole cache: generates CACHE_BATCHCOUNT batches, advancing the
  // counter by SIMD_WIDTH after each, and stores the advanced counter back into
  // m_state. The caller (next_block) resets m_cache_index.
  PRNG_ALWAYS_INLINE constexpr void gen_next_blocks_in_cache() noexcept {
    auto state = m_state;
    poet::static_for<0, CACHE_BATCHCOUNT>([&](auto Batch) {
      gen_block_batch(m_cache[Batch], state);
      advance_counter(state);
    });
    m_state = state;
  }
};

// Type-erased entry points for one concrete ChaCha state implementation.
//
// Every function pointer receives the state as an opaque pointer, so callers
// never need to know the concrete state type. Members are plain aggregate
// fields, in the order: next_block, get_state, set_state, get_cache_index,
// simd_size.
struct ChaChaSIMDInitResult {
  // A 16-word ChaCha block / input matrix.
  using matrix_type = std::array<std::uint32_t, 16>;

  // Signatures of the type-erased operations.
  using next_block_fn = auto (*)(void *) noexcept -> matrix_type;
  using get_state_fn = auto (*)(const void *, bool) noexcept -> matrix_type;
  using set_state_fn = auto (*)(void *, const matrix_type &) noexcept -> void;
  using get_cache_index_fn = auto (*)(const void *) noexcept -> std::uint8_t;

  next_block_fn next_block;           // produce the next output block
  get_state_fn get_state;             // read the input matrix (bool selects the previous block)
  set_state_fn set_state;             // overwrite the input matrix
  get_cache_index_fn get_cache_index; // read the state's block-cache index
  std::size_t simd_size;              // SIMD width of the selected implementation
};

// Callable that initialises an R-round ChaCha state for a given architecture
// tag. It carries the destination buffer and the construction arguments;
// operator() is declared here and defined out of line.
template <std::uint8_t R> struct ChaChaSIMDInitFunctor {
  // Destination buffer the state is constructed in.
  void *state_storage;
  // 256-bit key as eight 32-bit words.
  const std::array<std::uint32_t, 8> key;
  // Initial 64-bit block counter.
  const std::uint64_t counter;
  // 64-bit nonce.
  const std::uint64_t nonce;

  // Constructs the state for `Arch` in state_storage and returns its
  // type-erased entry points.
  template <class Arch>
  ChaChaSIMDInitResult operator()(Arch) const noexcept;
};

// Builds the architecture-specific generator state in caller-provided storage.
//
// Placement-constructs a ChaChaState<Arch, R> at `state_storage` from the
// stored key, counter and nonce, and returns the type-erased entry points
// (plus the SIMD width) that operate on that storage.
//
// The storage must be large and aligned enough for the state; the hard-coded
// limits below (2176 bytes, 64-byte alignment) are checked at compile time.
template <std::uint8_t R>
template <class Arch>
ChaChaSIMDInitResult ChaChaSIMDInitFunctor<R>::operator()(Arch) const noexcept {
  using State = ChaChaState<Arch, R>;
  static_assert(sizeof(State) <= 2176, "ChaChaState exceeds StateStorage capacity");
  static_assert(alignof(State) <= 64, "ChaChaState exceeds StateStorage alignment");
  // The state lives in a raw byte buffer: no destructor is ever run for it and
  // the owning object's implicit copies duplicate it byte by byte, so both
  // operations have to be trivial.
  static_assert(std::is_trivially_destructible_v<State>, "ChaChaState must be trivially destructible");
  static_assert(std::is_trivially_copyable_v<State>, "ChaChaState must be trivially copyable");

  // Global placement new, so a class-scope operator new can never be selected.
  ::new (state_storage) State(key, counter, nonce);

  // The trampolines receive a pointer to the byte buffer, not to the State
  // object created inside it; std::launder yields a pointer that may be used
  // to access that object.
  return {
      +[](void *s) noexcept -> ChaChaSIMDInitResult::matrix_type {
        return std::launder(static_cast<State *>(s))->next_block();
      },
      +[](const void *s, bool prev) noexcept -> ChaChaSIMDInitResult::matrix_type {
        return std::launder(static_cast<const State *>(s))->getState(prev);
      },
      +[](void *s, const ChaChaSIMDInitResult::matrix_type &matrix) noexcept {
        auto *state = std::launder(static_cast<State *>(s));
        state->m_state = matrix;
        // Discard every cached block so the next request regenerates from `matrix`.
        state->m_cache_index = State::CACHE_BLOCKCOUNT;
      },
      +[](const void *s) noexcept -> std::uint8_t {
        return std::launder(static_cast<const State *>(s))->m_cache_index;
      },
      std::size_t{State::SIMD_WIDTH},
  };
}

// Explicit-instantiation declarations: they stop this header from implicitly
// instantiating the initialiser for the listed (round count, architecture)
// pairs. The matching definitions are expected to come from elsewhere.
#define PRNG_CHACHA_EXTERN_TEMPLATE(R, Arch)                                                                           \
  extern template PRNG_EXPORT ChaChaSIMDInitResult ChaChaSIMDInitFunctor<R>::operator()<Arch>(Arch) const noexcept

// Declares the 8-, 12- and 20-round variants for one architecture, in that
// order. The trailing semicolon is supplied at the point of use.
#define PRNG_CHACHA_EXTERN_TEMPLATE_ALL_ROUNDS(Arch)                                                                   \
  PRNG_CHACHA_EXTERN_TEMPLATE(8, Arch);                                                                                \
  PRNG_CHACHA_EXTERN_TEMPLATE(12, Arch);                                                                               \
  PRNG_CHACHA_EXTERN_TEMPLATE(20, Arch)

#if PRNG_ARCH_X86_64
PRNG_CHACHA_EXTERN_TEMPLATE_ALL_ROUNDS(xsimd::sse2);
PRNG_CHACHA_EXTERN_TEMPLATE_ALL_ROUNDS(xsimd::avx2);
PRNG_CHACHA_EXTERN_TEMPLATE_ALL_ROUNDS(xsimd::avx512f);
#elif PRNG_ARCH_AARCH64
PRNG_CHACHA_EXTERN_TEMPLATE_ALL_ROUNDS(xsimd::neon64);
#  if XSIMD_WITH_SVE
PRNG_CHACHA_EXTERN_TEMPLATE_ALL_ROUNDS(xsimd::sve);
#  endif
#elif PRNG_ARCH_RISCV64
PRNG_CHACHA_EXTERN_TEMPLATE_ALL_ROUNDS(xsimd::detail::rvv<128>);
#endif

#undef PRNG_CHACHA_EXTERN_TEMPLATE_ALL_ROUNDS
#undef PRNG_CHACHA_EXTERN_TEMPLATE

} // namespace internal

// ChaCha-based 64-bit random generator whose SIMD implementation is chosen
// through xsimd::dispatch over dispatch_arch_list.
//
// The architecture-specific state is constructed inside an opaque byte buffer
// and driven through function pointers returned by the initialiser, so this
// class itself does not depend on any particular SIMD type. Output blocks of
// 16 x 32 bits are re-interpreted as 8 x 64-bit results and handed out one at
// a time by operator().
//
// Template parameter R is the ChaCha round count (default 20).
template <std::uint8_t R = 20> class ChaChaSIMD {
public:
  // Word counts of the full ChaCha matrix and of the key portion of it.
  static constexpr auto MATRIX_WORDCOUNT = std::uint8_t{16};
  static constexpr auto KEY_WORDCOUNT = std::uint8_t{8};

  using result_type = std::uint64_t;
  using input_word = std::uint64_t;
  using matrix_word = std::uint32_t;
  using matrix_type = std::array<matrix_word, MATRIX_WORDCOUNT>;
  // One block viewed as eight 64-bit results.
  using result_cache_type = std::array<result_type, MATRIX_WORDCOUNT / 2>;

  // Smallest / largest value operator() can return. The names are
  // parenthesised so that function-like min/max macros cannot expand them.
  static constexpr PRNG_ALWAYS_INLINE auto(min)() noexcept { return (std::numeric_limits<result_type>::min)(); }
  static constexpr PRNG_ALWAYS_INLINE auto(max)() noexcept { return (std::numeric_limits<result_type>::max)(); }

  // Re-interprets eight 64-bit results as a 16-word block (bit-for-bit).
  static constexpr PRNG_ALWAYS_INLINE matrix_type results_to_block(const result_cache_type &results) noexcept {
    return std::bit_cast<matrix_type>(results);
  }

  // Re-interprets a 16-word block as eight 64-bit results (bit-for-bit).
  static constexpr PRNG_ALWAYS_INLINE result_cache_type block_to_results(const matrix_type &block) noexcept {
    return std::bit_cast<result_cache_type>(block);
  }

  // Expands a 64-bit seed into a 256-bit key. Each of the four SplitMix64
  // outputs contributes two key words, low half first.
  static constexpr std::array<matrix_word, KEY_WORDCOUNT> seed_to_key(result_type seed) noexcept {
    std::array<matrix_word, KEY_WORDCOUNT> key{};
    // SplitMix64 expansion: 1 uint64 seed -> 4 uint64 -> 8 uint32
    auto state = seed;
    for (std::uint8_t i = 0; i < 4; ++i) {
      state += 0x9e3779b97f4a7c15ULL;
      auto z = state;
      z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ULL;
      z = (z ^ (z >> 27)) * 0x94d049bb133111ebULL;
      z = z ^ (z >> 31);
      key[i * 2] = static_cast<matrix_word>(z);
      key[i * 2 + 1] = static_cast<matrix_word>(z >> 32);
    }
    return key;
  }

  // Seeds the generator from a single 64-bit value (key derived via
  // seed_to_key); counter and nonce default to 0.
  explicit PRNG_ALWAYS_INLINE ChaChaSIMD(result_type seed, const input_word counter = 0, const input_word nonce = 0)
      : ChaChaSIMD(seed_to_key(seed), counter, nonce) {}

  // Constructs the generator from an explicit key, block counter and nonce.
  // The dispatched initialiser builds the architecture-specific state inside
  // m_state and returns the entry points stored below.
  explicit PRNG_ALWAYS_INLINE ChaChaSIMD(const std::array<matrix_word, KEY_WORDCOUNT> key, const input_word counter,
                                          const input_word nonce) {
    auto result =
        xsimd::dispatch<dispatch_arch_list>(
            internal::ChaChaSIMDInitFunctor<R>{m_state.data, key, counter, nonce})();
    m_next_block = result.next_block;
    m_get_state = result.get_state;
    m_set_state = result.set_state;
    m_get_cache_index = result.get_cache_index;
    m_simd_size = result.simd_size;
  }

  // Returns the next 64-bit result, fetching a fresh block when the eight
  // cached results have been used up.
  PRNG_ALWAYS_INLINE constexpr result_type operator()() noexcept {
    if (m_result_index >= m_result_cache.size()) [[unlikely]] {
      m_result_cache = block_to_results(m_next_block(m_state.data));
      m_result_index = 0;
    }
    return m_result_cache[m_result_index++];
  }

  // Returns a double in [0, 1): the top 53 bits of the next result scaled by
  // 2^-53.
  PRNG_ALWAYS_INLINE constexpr double uniform() noexcept {
    return static_cast<double>(operator()() >> 11) * 0x1.0p-53;
  }

  // Returns a whole 16-word block.
  //
  // If the result cache still holds unread values, the complete cached block
  // is returned (including any values operator() has already handed out) and
  // the cache is marked as consumed. Otherwise a fresh block is generated.
  PRNG_ALWAYS_INLINE constexpr matrix_type block() noexcept {
    if (m_result_index < m_result_cache.size()) {
      auto cached_block = results_to_block(m_result_cache);
      m_result_index = static_cast<std::uint8_t>(m_result_cache.size());
      return cached_block;
    }
    return m_next_block(m_state.data);
  }

  // Returns the ChaCha input matrix. While the result cache holds unread
  // values the "previous block" variant of the state is requested, otherwise
  // the state of the next block to be produced.
  PRNG_ALWAYS_INLINE constexpr matrix_type getState() const noexcept {
    return m_get_state(m_state.data, m_result_index < m_result_cache.size());
  }

  // Returns the input matrix of the next block to be produced, regardless of
  // the result cache; the result cache and index are exposed separately below.
  matrix_type getStateForSerde() const noexcept {
    return m_get_state(m_state.data, false);
  }

  // Replaces the input matrix and discards the underlying block cache. The
  // 64-bit result cache and its index are left untouched.
  void setState(const matrix_type &matrix) noexcept {
    m_set_state(m_state.data, matrix);
  }

  // Direct access to the 64-bit result cache and its read position.
  const result_cache_type &result_cache() const noexcept { return m_result_cache; }
  void set_result_cache(const result_cache_type &cache) noexcept { m_result_cache = cache; }
  std::uint8_t result_index() const noexcept { return m_result_index; }
  void set_result_index(std::uint8_t idx) noexcept { m_result_index = idx; }

  // SIMD width (blocks computed in parallel) of the dispatched implementation.
  PRNG_ALWAYS_INLINE std::size_t getSIMDSize() const noexcept { return m_simd_size; }

private:
  using next_block_fn = internal::ChaChaSIMDInitResult::next_block_fn;
  using get_state_fn = internal::ChaChaSIMDInitResult::get_state_fn;
  using set_state_fn = internal::ChaChaSIMDInitResult::set_state_fn;
  using get_cache_index_fn = internal::ChaChaSIMDInitResult::get_cache_index_fn;

  // Raw byte storage for the arch-specific ChaChaState.
  // Typed union is not viable: xsimd batch types have different sizeof
  // across TUs compiled with different -march flags (ODR divergence).
  // Max is avx512f: 64B state + 2048B cache + 1B index + padding = 2176 bytes.
  struct StateStorage {
    static constexpr std::size_t SIZE = 2176;
    static constexpr std::size_t ALIGN = 64;
    alignas(ALIGN) unsigned char data[SIZE];
  };

  alignas(64) StateStorage m_state;
  // Entry points of the dispatched implementation; set by the constructor.
  next_block_fn m_next_block = nullptr;
  get_state_fn m_get_state = nullptr;
  set_state_fn m_set_state = nullptr;
  get_cache_index_fn m_get_cache_index = nullptr;
  std::size_t m_simd_size = 0;
  // Current block as 64-bit results; an index equal to size() means "empty".
  result_cache_type m_result_cache{};
  std::uint8_t m_result_index = static_cast<std::uint8_t>(m_result_cache.size());
};

#ifndef XSIMD_NO_SUPPORTED_ARCHITECTURE
// ChaCha-based 64-bit random generator bound at compile time to
// xsimd::best_arch.
//
// Unlike ChaChaSIMD, the architecture-specific state is a direct member and
// is called without function-pointer indirection. The block/result helpers
// and the seed expansion are shared with ChaChaSIMD<R>.
//
// Template parameter R is the ChaCha round count (default 20).
template <std::uint8_t R = 20> class ChaChaNative {
public:
  // Word counts of the full ChaCha matrix and of the key portion of it.
  static constexpr auto MATRIX_WORDCOUNT = std::uint8_t{16};
  static constexpr auto KEY_WORDCOUNT = std::uint8_t{8};

  using result_type = std::uint64_t;
  using input_word = std::uint64_t;
  using matrix_word = std::uint32_t;
  using matrix_type = std::array<matrix_word, MATRIX_WORDCOUNT>;
  // One block viewed as eight 64-bit results.
  using result_cache_type = std::array<result_type, MATRIX_WORDCOUNT / 2>;

  // Smallest / largest value operator() can return. The names are
  // parenthesised so that function-like min/max macros cannot expand them.
  static constexpr PRNG_ALWAYS_INLINE auto(min)() noexcept { return (std::numeric_limits<result_type>::min)(); }
  static constexpr PRNG_ALWAYS_INLINE auto(max)() noexcept { return (std::numeric_limits<result_type>::max)(); }

  // Seeds the generator from a single 64-bit value (key derived via
  // ChaChaSIMD<R>::seed_to_key); counter and nonce default to 0.
  explicit ChaChaNative(result_type seed, const input_word counter = 0, const input_word nonce = 0)
      : ChaChaNative(ChaChaSIMD<R>::seed_to_key(seed), counter, nonce) {}

  // Constructs the generator from an explicit key, block counter and nonce.
  ChaChaNative(const std::array<matrix_word, KEY_WORDCOUNT> key, const input_word counter, const input_word nonce)
      : m_state(key, counter, nonce) {}

  // Returns the next 64-bit result, fetching a fresh block when the eight
  // cached results have been used up.
  PRNG_ALWAYS_INLINE constexpr result_type operator()() noexcept {
    if (m_result_index >= m_result_cache.size()) [[unlikely]] {
      m_result_cache = ChaChaSIMD<R>::block_to_results(m_state.next_block());
      m_result_index = 0;
    }
    return m_result_cache[m_result_index++];
  }

  // Returns a double in [0, 1): the top 53 bits of the next result scaled by
  // 2^-53.
  PRNG_ALWAYS_INLINE constexpr double uniform() noexcept {
    return static_cast<double>(operator()() >> 11) * 0x1.0p-53;
  }

  // Returns a whole 16-word block.
  //
  // If the result cache still holds unread values, the complete cached block
  // is returned (including any values operator() has already handed out) and
  // the cache is marked as consumed. Otherwise a fresh block is generated.
  PRNG_ALWAYS_INLINE constexpr matrix_type block() noexcept {
    if (m_result_index < m_result_cache.size()) {
      auto cached_block = ChaChaSIMD<R>::results_to_block(m_result_cache);
      m_result_index = static_cast<std::uint8_t>(m_result_cache.size());
      return cached_block;
    }
    return m_state.next_block();
  }

  // Returns the ChaCha input matrix. While the result cache holds unread
  // values the "previous block" variant of the state is requested, otherwise
  // the state of the next block to be produced.
  PRNG_ALWAYS_INLINE constexpr matrix_type getState() const noexcept {
    return m_state.getState(m_result_index < m_result_cache.size());
  }

  // Returns the input matrix of the next block to be produced, regardless of
  // the result cache; the result cache and index are exposed separately below.
  matrix_type getStateForSerde() const noexcept {
    return m_state.getState(false);
  }

  // Replaces the input matrix and discards the block cache. The 64-bit result
  // cache and its index are left untouched.
  void setState(const matrix_type &matrix) noexcept {
    m_state.m_state = matrix;
    m_state.m_cache_index = State::CACHE_BLOCKCOUNT;
  }

  // Direct access to the 64-bit result cache and its read position.
  const result_cache_type &result_cache() const noexcept { return m_result_cache; }
  void set_result_cache(const result_cache_type &cache) noexcept { m_result_cache = cache; }
  std::uint8_t result_index() const noexcept { return m_result_index; }
  void set_result_index(std::uint8_t idx) noexcept { m_result_index = idx; }

  // SIMD width (blocks computed in parallel) of the compile-time architecture.
  PRNG_ALWAYS_INLINE std::size_t getSIMDSize() const noexcept { return std::size_t{State::SIMD_WIDTH}; }

private:
  using State = internal::ChaChaState<xsimd::best_arch, R>;
  State m_state;
  // Current block as 64-bit results; an index equal to size() means "empty".
  result_cache_type m_result_cache{};
  std::uint8_t m_result_index = static_cast<std::uint8_t>(m_result_cache.size());
};
#endif // XSIMD_NO_SUPPORTED_ARCHITECTURE

// Convenience aliases for common ChaCha variants: the template argument is the
// round count R (8, 12 or 20) of the runtime-dispatched generator.
using ChaCha8SIMD = ChaChaSIMD<8>;
using ChaCha12SIMD = ChaChaSIMD<12>;
using ChaCha20SIMD = ChaChaSIMD<20>;

// Matching aliases for ChaChaNative. They are only declared when
// XSIMD_NO_SUPPORTED_ARCHITECTURE is not defined, mirroring the guard around
// the ChaChaNative class itself.
#ifndef XSIMD_NO_SUPPORTED_ARCHITECTURE
using ChaCha8Native = ChaChaNative<8>;
using ChaCha12Native = ChaChaNative<12>;
using ChaCha20Native = ChaChaNative<20>;
#endif // XSIMD_NO_SUPPORTED_ARCHITECTURE

} // namespace prng