Program Listing for File philox.hpp

Return to documentation for file (include/random/philox.hpp)

#pragma once

#include <array>
#include <bit>
#include <cstdint>
#include <limits>
#include <type_traits>

#include "macros.hpp"

namespace prng {

namespace internal {

template <std::uint8_t N, std::uint8_t W>
struct PhiloxConstants;

template <>
struct PhiloxConstants<4, 32> {
  static constexpr std::uint32_t M0 = 0xD2511F53;
  static constexpr std::uint32_t M1 = 0xCD9E8D57;
  static constexpr std::uint32_t W0 = 0x9E3779B9;
  static constexpr std::uint32_t W1 = 0xBB67AE85;
};

template <>
struct PhiloxConstants<2, 32> {
  static constexpr std::uint32_t M0 = 0xD256D193;
  static constexpr std::uint32_t W0 = 0x9E3779B9;
};

template <>
struct PhiloxConstants<4, 64> {
  static constexpr std::uint64_t M0 = 0xD2E7470EE14C6C93ULL;
  static constexpr std::uint64_t M1 = 0xCA5A826395121157ULL;
  static constexpr std::uint64_t W0 = 0x9E3779B97F4A7C15ULL;
  static constexpr std::uint64_t W1 = 0xBB67AE8584CAA73BULL;
};

template <>
struct PhiloxConstants<2, 64> {
  static constexpr std::uint64_t M0 = 0xD2B74407B1CE6E93ULL;
  static constexpr std::uint64_t W0 = 0x9E3779B97F4A7C15ULL;
};

} // namespace internal

template <std::uint8_t N = 4, std::uint8_t W = 32, std::uint8_t R = 10>
class Philox {
  static_assert(N == 2 || N == 4, "Philox N must be 2 or 4");
  static_assert(W == 32 || W == 64, "Philox W must be 32 or 64");
  static_assert(R > 0, "Philox rounds must be > 0");

public:
  using result_type = std::uint64_t;
  using word_type = std::conditional_t<W == 32, std::uint32_t, std::uint64_t>;
  using counter_type = std::array<word_type, N>;
  using key_type = std::array<word_type, N / 2>;

  static constexpr auto RESULTS_PER_BLOCK = std::uint8_t{N * W / 64};
  using result_block_type = std::array<result_type, RESULTS_PER_BLOCK>;

  static constexpr PRNG_ALWAYS_INLINE auto(min)() noexcept {
    return (std::numeric_limits<result_type>::min)();
  }

  static constexpr PRNG_ALWAYS_INLINE auto(max)() noexcept {
    return (std::numeric_limits<result_type>::max)();
  }

  explicit PRNG_ALWAYS_INLINE Philox(result_type seed, result_type counter = 0) noexcept
      : m_counter(counter_from_uint64(counter)), m_key(seed_to_key(seed)) {}

  explicit PRNG_ALWAYS_INLINE Philox(key_type key, counter_type counter) noexcept
      : m_counter(counter), m_key(key) {}

  PRNG_ALWAYS_INLINE constexpr result_type operator()() noexcept {
    if (m_result_index >= RESULTS_PER_BLOCK) [[unlikely]] {
      m_result_cache = next_block();
      m_result_index = 0;
    }
    return m_result_cache[m_result_index++];
  }

  PRNG_ALWAYS_INLINE constexpr double uniform() noexcept {
    return static_cast<double>(operator()() >> 11) * 0x1.0p-53;
  }

  counter_type getCounter() const noexcept {
    if (m_result_index < RESULTS_PER_BLOCK) {
      counter_type ctr = m_counter;
      dec_counter(ctr);
      return ctr;
    }
    return m_counter;
  }

  key_type getKey() const noexcept { return m_key; }

  void setCounter(const counter_type &ctr) noexcept {
    m_counter = ctr;
    m_result_index = RESULTS_PER_BLOCK;
  }

  void setKey(const key_type &key) noexcept {
    m_key = key;
    m_result_index = RESULTS_PER_BLOCK;
  }

  counter_type getCounterForSerde() const noexcept {
    return m_counter;
  }

  void setState(const counter_type &ctr, const key_type &key) noexcept {
    m_counter = ctr;
    m_key = key;
    m_result_index = RESULTS_PER_BLOCK;
  }

  const result_block_type &result_cache() const noexcept { return m_result_cache; }
  void set_result_cache(const result_block_type &cache) noexcept { m_result_cache = cache; }
  std::uint8_t result_index() const noexcept { return m_result_index; }
  void set_result_index(std::uint8_t idx) noexcept { m_result_index = idx; }

  static constexpr key_type seed_to_key(result_type seed) noexcept {
    key_type key{};
    auto state = seed;
    auto splitmix = [&state]() -> std::uint64_t {
      state += 0x9e3779b97f4a7c15ULL;
      auto z = state;
      z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ULL;
      z = (z ^ (z >> 27)) * 0x94d049bb133111ebULL;
      return z ^ (z >> 31);
    };
    if constexpr (W == 32) {
      auto z = splitmix();
      key[0] = static_cast<word_type>(z);
      if constexpr (N == 4) {
        key[1] = static_cast<word_type>(z >> 32);
      }
    } else {
      key[0] = static_cast<word_type>(splitmix());
      if constexpr (N == 4) {
        key[1] = static_cast<word_type>(splitmix());
      }
    }
    return key;
  }

  static constexpr counter_type counter_from_uint64(result_type counter) noexcept {
    counter_type ctr{};
    if constexpr (W == 32) {
      ctr[0] = static_cast<word_type>(counter & 0xFFFFFFFF);
      if constexpr (N >= 2) {
        ctr[1] = static_cast<word_type>(counter >> 32);
      }
    } else {
      ctr[0] = static_cast<word_type>(counter);
    }
    return ctr;
  }

private:
  using C = internal::PhiloxConstants<N, W>;

  counter_type m_counter;
  key_type m_key;
  result_block_type m_result_cache{};
  std::uint8_t m_result_index = RESULTS_PER_BLOCK;

  static constexpr PRNG_ALWAYS_INLINE void mulhilo(word_type a, word_type b,
                                                     word_type &hi, word_type &lo) noexcept {
    if constexpr (W == 32) {
      auto product = static_cast<std::uint64_t>(a) * static_cast<std::uint64_t>(b);
      lo = static_cast<word_type>(product);
      hi = static_cast<word_type>(product >> 32);
    } else {
#if defined(__SIZEOF_INT128__)
      auto product = static_cast<__uint128_t>(a) * static_cast<__uint128_t>(b);
      lo = static_cast<word_type>(product);
      hi = static_cast<word_type>(product >> 64);
#else
      constexpr std::uint64_t MASK32 = 0xFFFFFFFF;
      std::uint64_t a_lo = a & MASK32, a_hi = a >> 32;
      std::uint64_t b_lo = b & MASK32, b_hi = b >> 32;
      std::uint64_t p_ll = a_lo * b_lo;
      std::uint64_t p_lh = a_lo * b_hi;
      std::uint64_t p_hl = a_hi * b_lo;
      std::uint64_t p_hh = a_hi * b_hi;
      std::uint64_t mid = (p_ll >> 32) + (p_lh & MASK32) + (p_hl & MASK32);
      lo = (p_ll & MASK32) | (mid << 32);
      hi = p_hh + (p_lh >> 32) + (p_hl >> 32) + (mid >> 32);
#endif
    }
  }

  static constexpr PRNG_ALWAYS_INLINE void single_round(counter_type &ctr, key_type &key) noexcept {
    if constexpr (N == 4) {
      word_type hi0, lo0, hi1, lo1;
      mulhilo(C::M0, ctr[0], hi0, lo0);
      mulhilo(C::M1, ctr[2], hi1, lo1);
      ctr = {hi1 ^ ctr[1] ^ key[0], lo1, hi0 ^ ctr[3] ^ key[1], lo0};
      key[0] += C::W0;
      key[1] += C::W1;
    } else {
      word_type hi, lo;
      mulhilo(C::M0, ctr[0], hi, lo);
      ctr = {hi ^ ctr[1] ^ key[0], lo};
      key[0] += C::W0;
    }
  }

  static constexpr PRNG_ALWAYS_INLINE counter_type philox_rounds(counter_type ctr, key_type key) noexcept {
    for (std::uint8_t i = 0; i < R; ++i) {
      single_round(ctr, key);
    }
    return ctr;
  }

  constexpr PRNG_ALWAYS_INLINE void inc_counter() noexcept {
    for (std::uint8_t i = 0; i < N; ++i) {
      if (++m_counter[i] != 0) break;
    }
  }

  static constexpr PRNG_ALWAYS_INLINE void dec_counter(counter_type &ctr) noexcept {
    for (std::uint8_t i = 0; i < N; ++i) {
      if (ctr[i]-- != 0) break;
    }
  }

  PRNG_FLATTEN constexpr PRNG_ALWAYS_INLINE result_block_type next_block() noexcept {
    auto output = philox_rounds(m_counter, m_key);
    inc_counter();
    return std::bit_cast<result_block_type>(output);
  }
};

using Philox4x32 = Philox<4, 32, 10>;
using Philox2x32 = Philox<2, 32, 10>;
using Philox4x64 = Philox<4, 64, 10>;
using Philox2x64 = Philox<2, 64, 10>;

} // namespace prng