Program Listing for File philox_simd.hpp

Return to documentation for file (include/random/philox_simd.hpp)

#pragma once

#include <array>
#include <bit>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <new>
#include <type_traits>

#include <poet/poet.hpp>

#include "dispatch_arch.hpp"
#include "macros.hpp"
#include "philox.hpp"

namespace prng {

namespace internal {

// Philox counter/key state bound to one xsimd architecture.
//
// Holds the block counter and the key, and fills a CACHE_SIZE-entry array of
// 64-bit results by pushing SIMD_WIDTH consecutive counter blocks through R
// rounds at a time.
//
//   Arch - xsimd architecture tag used to select the batch type.
//   N    - words per counter block (2 or 4).
//   W    - bits per word (32 or 64).
//   R    - number of rounds applied to every block.
template <class Arch, std::uint8_t N, std::uint8_t W, std::uint8_t R>
struct PhiloxState {
  static_assert(N == 2 || N == 4, "Philox N must be 2 or 4");
  static_assert(W == 32 || W == 64, "Philox W must be 32 or 64");

  using word_type = std::conditional_t<W == 32, std::uint32_t, std::uint64_t>;
  using counter_type = std::array<word_type, N>;
  using key_type = std::array<word_type, N / 2>;
  using result_type = std::uint64_t;
  // Number of 64-bit results obtained from one N x W-bit block.
  static constexpr auto RESULTS_PER_BLOCK = std::uint8_t{N * W / 64};

  using simd_type = xsimd::batch<word_type, Arch>;
  // Lanes per batch, i.e. how many blocks one batch processes at once.
  static constexpr std::uint8_t SIMD_WIDTH = std::uint8_t{simd_type::size};
  static_assert(SIMD_WIDTH == 0 || std::has_single_bit(static_cast<unsigned int>(SIMD_WIDTH)),
                "Philox SIMD width must be a power of two");

  static constexpr std::uint16_t CACHE_SIZE = 256;
  static constexpr std::uint16_t BLOCKS_PER_CACHE = CACHE_SIZE / RESULTS_PER_BLOCK;
  static_assert(CACHE_SIZE % RESULTS_PER_BLOCK == 0,
                "CACHE_SIZE must be a multiple of RESULTS_PER_BLOCK");
  static constexpr std::uint16_t BATCHES_PER_CACHE =
      SIMD_WIDTH == 0 ? 1 : BLOCKS_PER_CACHE / SIMD_WIDTH;
  static_assert(SIMD_WIDTH == 0 || BLOCKS_PER_CACHE % SIMD_WIDTH == 0,
                "BLOCKS_PER_CACHE must be a multiple of SIMD_WIDTH");

  // Counter of the next block to be generated (word 0 is least significant).
  counter_type m_counter;
  key_type m_key;

  explicit PRNG_ALWAYS_INLINE PhiloxState(const key_type &key, const counter_type &counter) noexcept
      : m_counter(counter), m_key(key) {}

  // Fills all of `cache` and moves the stored counter past the generated
  // blocks. Each of the BATCHES_PER_CACHE batches writes
  // SIMD_WIDTH * RESULTS_PER_BLOCK results and advances the counter by
  // SIMD_WIDTH.
  PRNG_ALWAYS_INLINE void populate_cache(std::array<result_type, CACHE_SIZE> &cache) noexcept {
    auto counter = m_counter;
    poet::static_for<0, BATCHES_PER_CACHE>([&](auto I) {
      gen_block_batch(cache.data() + I.value * SIMD_WIDTH * RESULTS_PER_BLOCK,
                      counter, m_key);
      advance_counter(counter, static_cast<word_type>(SIMD_WIDTH));
    });
    m_counter = counter;
  }

  // Returns the stored counter stepped back by
  //   BLOCKS_PER_CACHE - cache_index / RESULTS_PER_BLOCK
  // blocks, plus one more block when `prev` is true. With prev == false and
  // cache_index == 0 that is the raw counter minus BLOCKS_PER_CACHE.
  //
  // The step back is a real multi-word subtraction. Adding the two's
  // complement of the amount to word 0 through advance_counter() is not
  // equivalent: it raises a carry into word 1 precisely when no borrow is
  // needed and never borrows when one is.
  counter_type getCounter(bool prev, std::uint16_t cache_index) const noexcept {
    counter_type ctr = m_counter;
    const auto consumed = static_cast<word_type>(cache_index / RESULTS_PER_BLOCK);
    auto back_off = static_cast<word_type>(BLOCKS_PER_CACHE) - consumed;
    if (prev) ++back_off;
    retreat_counter(ctr, back_off);
    return ctr;
  }

  // Stored counter exactly as held, without any cache-position adjustment.
  counter_type getRawCounter() const noexcept { return m_counter; }

  key_type getKey() const noexcept { return m_key; }

  // Overwrites counter and key. Any cache owned by the caller is untouched.
  void setState(const counter_type &ctr, const key_type &key) noexcept {
    m_counter = ctr;
    m_key = key;
  }

private:
  using C = PhiloxConstants<N, W>;

  // Adds `amount` to the multi-word counter, rippling a carry out of word 0
  // through the higher words.
  static PRNG_ALWAYS_INLINE void advance_counter(counter_type &ctr, word_type amount) noexcept {
    word_type old = ctr[0];
    ctr[0] += amount;
    if constexpr (N >= 2) {
      if (ctr[0] < old) { // carry
        for (std::uint8_t i = 1; i < N; ++i) {
          if (++ctr[i] != 0) break;
        }
      }
    }
  }

  // Subtracts `amount` from the multi-word counter, rippling a borrow out of
  // word 0 through the higher words (the whole counter wraps modulo 2^(N*W)).
  static PRNG_ALWAYS_INLINE void retreat_counter(counter_type &ctr, word_type amount) noexcept {
    const word_type old = ctr[0];
    ctr[0] -= amount;
    if (ctr[0] > old) { // borrow: word 0 wrapped below zero
      for (std::uint8_t i = 1; i < N; ++i) {
        // A word that was non-zero absorbs the borrow; a zero word wraps and
        // passes it on.
        if (ctr[i]-- != 0) break;
      }
    }
  }

  // Multiplies every lane of `a` by the scalar B via xsimd::mulhilo and
  // returns the two halves of the pair it yields (first -> hi, second -> lo).
  static PRNG_ALWAYS_INLINE void mulhilo_simd(simd_type a, word_type B,
                                                simd_type &hi, simd_type &lo) noexcept {
    auto hilo = xsimd::mulhilo(a, simd_type::broadcast(B));
    hi = hilo.first;
    lo = hilo.second;
  }

  // Applies one round to all lanes of `ctr` using the current `key`, then
  // bumps `key` by the C::W* constants so the next call uses the next round
  // key.
  static PRNG_ALWAYS_INLINE void simd_single_round(std::array<simd_type, N> &ctr,
                                                     key_type &key) noexcept {
    if constexpr (N == 4) {
      simd_type hi0, lo0, hi1, lo1;
      mulhilo_simd(ctr[0], C::M0, hi0, lo0);
      mulhilo_simd(ctr[2], C::M1, hi1, lo1);
      auto k0 = simd_type::broadcast(key[0]);
      auto k1 = simd_type::broadcast(key[1]);
      // Right-hand sides read the old ctr[1]/ctr[3] before they are replaced.
      ctr[0] = hi1 ^ ctr[1] ^ k0;
      ctr[1] = lo1;
      ctr[2] = hi0 ^ ctr[3] ^ k1;
      ctr[3] = lo0;
      key[0] += C::W0;
      key[1] += C::W1;
    } else {
      simd_type hi, lo;
      mulhilo_simd(ctr[0], C::M0, hi, lo);
      auto k0 = simd_type::broadcast(key[0]);
      ctr[0] = hi ^ ctr[1] ^ k0;
      ctr[1] = lo;
      key[0] += C::W0;
    }
  }

  // Builds the per-lane counters: lane L receives `counter + L`, with the
  // carry out of each word forwarded into the next word of the same lane.
  static PRNG_ALWAYS_INLINE void init_counter_batch(
      std::array<simd_type, N> &ctr_simd,
      const counter_type &counter) noexcept {
    alignas(simd_type::arch_type::alignment()) std::array<word_type, SIMD_WIDTH> offsets{};
    poet::static_for<0, SIMD_WIDTH>([&](auto I) {
      offsets[I] = static_cast<word_type>(I.value);
    });

    // For word 0 the "carry" is the lane offset itself; afterwards it is the
    // 0/1 overflow flag of the previous word.
    auto carry = simd_type::load_aligned(offsets.data());
    poet::static_for<0, N>([&](auto I) {
      auto base = simd_type::broadcast(counter[I]);
      auto val = base + carry;
      ctr_simd[I] = val;
      if constexpr (I + 1 < N) {
        carry = xsimd::select(val < base,
                              simd_type::broadcast(word_type{1}),
                              simd_type::broadcast(word_type{0}));
      }
    });
  }

  // Writes the processed blocks to `cache`, one block after another, so lane
  // L occupies cache[L * RESULTS_PER_BLOCK ...]. `cache` must satisfy the
  // alignment required by store_aligned in the transposing path.
  static PRNG_ALWAYS_INLINE void store_blocks_to_cache(result_type *cache,
                                     const std::array<simd_type, N> &ctr_simd) noexcept {
    if constexpr (W == 64 && N == SIMD_WIDTH) {
      // Square case: transpose word-major registers into lane-major ones and
      // store each register as one whole block.
      std::array<simd_type, SIMD_WIDTH> regs;
      poet::static_for<0, N>([&](auto I) { regs[I] = ctr_simd[I]; });
      xsimd::transpose(regs.data(), regs.data() + SIMD_WIDTH);
      poet::static_for<0, SIMD_WIDTH>([&](auto Lane) {
        regs[Lane].store_aligned(cache + Lane * RESULTS_PER_BLOCK);
      });
    } else {
      // General case: spill every register to memory and gather per lane.
      alignas(simd_type::arch_type::alignment()) std::array<word_type, SIMD_WIDTH> regs[N];
      poet::static_for<0, N>([&](auto I) {
        ctr_simd[I].store_aligned(regs[I].data());
      });

      for (std::uint8_t lane = 0; lane < SIMD_WIDTH; ++lane) {
        if constexpr (W == 32) {
          // Two consecutive 32-bit words form one result, even word low.
          for (std::uint8_t k = 0; k < RESULTS_PER_BLOCK; ++k) {
            auto lo32 = static_cast<std::uint64_t>(regs[2 * k][lane]);
            auto hi32 = static_cast<std::uint64_t>(regs[2 * k + 1][lane]);
            cache[lane * RESULTS_PER_BLOCK + k] = lo32 | (hi32 << 32);
          }
        } else {
          for (std::uint8_t k = 0; k < N; ++k) {
            cache[lane * RESULTS_PER_BLOCK + k] = regs[k][lane];
          }
        }
      }
    }
  }

  // Generates SIMD_WIDTH blocks starting at `counter` and stores them at
  // `cache`. Works on a private copy of the key, so `key` is not modified.
  static PRNG_ALWAYS_INLINE void gen_block_batch(result_type *cache,
                                                  const counter_type &counter,
                                                  const key_type &key) noexcept {
    std::array<simd_type, N> ctr_simd;
    init_counter_batch(ctr_simd, counter);

    key_type round_key = key;
    poet::static_for<0, R>([&](auto) {
      simd_single_round(ctr_simd, round_key);
    });

    store_blocks_to_cache(cache, ctr_simd);
  }
};

// Result of architecture dispatch: a table of plain function pointers, one
// per state operation, plus the SIMD size reported by the chosen
// implementation. Every entry takes the state object as an untyped pointer
// so the caller does not need to know its concrete type.
template <std::uint8_t N, std::uint8_t W>
struct PhiloxSIMDInitResult {
  using result_type = std::uint64_t;
  using word_type = std::conditional_t<W == 32, std::uint32_t, std::uint64_t>;
  using key_type = std::array<word_type, N / 2>;
  using counter_type = std::array<word_type, N>;

  static constexpr std::uint16_t CACHE_SIZE = 256;
  using cache_type = std::array<result_type, CACHE_SIZE>;

  // Signatures of the table entries.
  using populate_fn = auto (*)(void *, cache_type &) noexcept -> void;
  using get_counter_fn = auto (*)(const void *, bool, std::uint16_t) noexcept -> counter_type;
  using get_raw_counter_fn = auto (*)(const void *) noexcept -> counter_type;
  using get_key_fn = auto (*)(const void *) noexcept -> key_type;
  using set_state_fn = auto (*)(void *, const counter_type &, const key_type &) noexcept -> void;

  // Data members. Their order is part of the contract: the struct is filled
  // by positional aggregate initialization.
  populate_fn populate_cache;
  get_counter_fn get_counter;
  get_raw_counter_fn get_raw_counter;
  get_key_fn get_key;
  set_state_fn set_state;
  std::size_t simd_size;
};

// Callable handed to the architecture dispatcher. It carries the destination
// buffer and the initial key/counter; invoking it with an architecture tag
// yields the function-pointer table for that architecture.
template <std::uint8_t N, std::uint8_t W, std::uint8_t R>
struct PhiloxSIMDInitFunctor {
  using word_type = std::conditional_t<W == 32, std::uint32_t, std::uint64_t>;
  using counter_type = std::array<word_type, N>;
  using key_type = std::array<word_type, N / 2>;

  // Aggregate members, initialized positionally in exactly this order.
  void *state_storage;        // buffer the state object is constructed in
  const key_type key;         // initial key
  const counter_type counter; // initial counter

  // Defined out of line; `Arch` is only used as a type tag.
  template <class Arch>
  PhiloxSIMDInitResult<N, W> operator()(Arch) const noexcept;
};

// Constructs a PhiloxState<Arch, N, W, R> in `state_storage` from the stored
// key and counter, and returns the table of type-erased entry points that
// forward to it, followed by the state's SIMD width.
//
// The state object lives in a raw buffer whose owner never runs a destructor
// and is copied as bytes, so the state type must be trivially copyable and
// trivially destructible; this is enforced at compile time.
template <std::uint8_t N, std::uint8_t W, std::uint8_t R>
template <class Arch>
PhiloxSIMDInitResult<N, W> PhiloxSIMDInitFunctor<N, W, R>::operator()(Arch) const noexcept {
  using State = PhiloxState<Arch, N, W, R>;
  using InitResult = PhiloxSIMDInitResult<N, W>;
  static_assert(sizeof(State) <= 256, "PhiloxState exceeds StateStorage capacity");
  static_assert(alignof(State) <= 64, "PhiloxState exceeds StateStorage alignment");
  static_assert(std::is_trivially_copyable_v<State> && std::is_trivially_destructible_v<State>,
                "PhiloxState must be trivially copyable and destructible to live in raw storage");
  // Global placement new: never picks up a class-specific operator new.
  ::new (state_storage) State(key, counter);
  // Order of the entries matches the member order of PhiloxSIMDInitResult.
  return {
      +[](void *s, std::array<typename InitResult::result_type, InitResult::CACHE_SIZE> &cache) noexcept {
        static_cast<State *>(s)->populate_cache(cache);
      },
      +[](const void *s, bool prev, std::uint16_t idx) noexcept -> typename InitResult::counter_type {
        return static_cast<const State *>(s)->getCounter(prev, idx);
      },
      +[](const void *s) noexcept -> typename InitResult::counter_type {
        return static_cast<const State *>(s)->getRawCounter();
      },
      +[](const void *s) noexcept -> typename InitResult::key_type {
        return static_cast<const State *>(s)->getKey();
      },
      +[](void *s, const typename InitResult::counter_type &ctr,
          const typename InitResult::key_type &key) noexcept {
        static_cast<State *>(s)->setState(ctr, key);
      },
      std::size_t{State::SIMD_WIDTH},
  };
}

// Extern template declarations for all NxW combos and architectures
//
// PRNG_PHILOX_EXTERN_TEMPLATE declares (does not define) one explicit
// instantiation of PhiloxSIMDInitFunctor<N, W, R>::operator()<Arch>, so that
// including this header does not instantiate it implicitly.
// NOTE(review): the matching explicit instantiation definitions are not in
// this file; presumably they live in per-architecture translation units.
#define PRNG_PHILOX_EXTERN_TEMPLATE(N, W, R, Arch)                                                  \
  extern template PRNG_EXPORT PhiloxSIMDInitResult<N, W>                                            \
  PhiloxSIMDInitFunctor<N, W, R>::operator()<Arch>(Arch) const noexcept

// Declares the four 10-round variants (4x32, 2x32, 4x64, 2x64) for one
// architecture. The last line has no trailing semicolon; the use site
// supplies it.
#define PRNG_PHILOX_EXTERN_TEMPLATES_FOR_ARCH(Arch)                                                 \
  PRNG_PHILOX_EXTERN_TEMPLATE(4, 32, 10, Arch);                                                    \
  PRNG_PHILOX_EXTERN_TEMPLATE(2, 32, 10, Arch);                                                    \
  PRNG_PHILOX_EXTERN_TEMPLATE(4, 64, 10, Arch);                                                    \
  PRNG_PHILOX_EXTERN_TEMPLATE(2, 64, 10, Arch)

// Architectures declared per target platform.
#if PRNG_ARCH_X86_64
PRNG_PHILOX_EXTERN_TEMPLATES_FOR_ARCH(xsimd::sse2);
PRNG_PHILOX_EXTERN_TEMPLATES_FOR_ARCH(xsimd::avx2);
PRNG_PHILOX_EXTERN_TEMPLATES_FOR_ARCH(xsimd::avx512f);
#elif PRNG_ARCH_AARCH64
PRNG_PHILOX_EXTERN_TEMPLATES_FOR_ARCH(xsimd::neon64);
#  if XSIMD_WITH_SVE
PRNG_PHILOX_EXTERN_TEMPLATES_FOR_ARCH(xsimd::sve);
#  endif
#elif PRNG_ARCH_RISCV64
PRNG_PHILOX_EXTERN_TEMPLATES_FOR_ARCH(xsimd::detail::rvv<128>);
#endif

// The helper macros are local to this header.
#undef PRNG_PHILOX_EXTERN_TEMPLATES_FOR_ARCH
#undef PRNG_PHILOX_EXTERN_TEMPLATE

} // namespace internal

// PhiloxSIMD: runtime SIMD dispatch via inline storage + function pointers.
//
// The constructor lets xsimd::dispatch pick an architecture from
// dispatch_arch_list; the selected implementation constructs its state inside
// m_state and hands back function pointers that every other member forwards
// to. Results are served from a CACHE_SIZE-entry cache that is refilled
// whenever the 8-bit read index is 0.
//
//   N - words per counter block, W - bits per word, R - rounds.
template <std::uint8_t N = 4, std::uint8_t W = 32, std::uint8_t R = 10>
class PhiloxSIMD {
public:
  using result_type = std::uint64_t;
  using word_type = std::conditional_t<W == 32, std::uint32_t, std::uint64_t>;
  using counter_type = std::array<word_type, N>;
  using key_type = std::array<word_type, N / 2>;
  static constexpr std::uint16_t CACHE_SIZE = 256;

  // Smallest / largest value operator() can return. The parentheses around
  // the names keep function-like min/max macros from expanding.
  static constexpr PRNG_ALWAYS_INLINE auto(min)() noexcept { return (std::numeric_limits<result_type>::min)(); }
  static constexpr PRNG_ALWAYS_INLINE auto(max)() noexcept { return (std::numeric_limits<result_type>::max)(); }

  // Seed -> key conversion, delegated to Philox<N, W, R>.
  static constexpr key_type seed_to_key(result_type seed) noexcept {
    return Philox<N, W, R>::seed_to_key(seed);
  }

  // 64-bit value -> counter conversion, delegated to Philox<N, W, R>.
  static constexpr counter_type counter_from_uint64(result_type counter) noexcept {
    return Philox<N, W, R>::counter_from_uint64(counter);
  }

  // Builds key and counter from 64-bit values and forwards to the
  // (key, counter) constructor.
  explicit PRNG_ALWAYS_INLINE PhiloxSIMD(result_type seed, result_type counter = 0) noexcept
      : PhiloxSIMD(seed_to_key(seed), counter_from_uint64(counter)) {}

  // Runs the architecture dispatch, which constructs the state in m_state,
  // and stores the returned entry points.
  explicit PRNG_ALWAYS_INLINE PhiloxSIMD(key_type key, counter_type counter) noexcept {
    auto result = xsimd::dispatch<dispatch_arch_list>(
        internal::PhiloxSIMDInitFunctor<N, W, R>{m_state.data, key, counter})();
    m_populate_cache = result.populate_cache;
    m_get_counter = result.get_counter;
    m_get_raw_counter = result.get_raw_counter;
    m_get_key = result.get_key;
    m_set_state = result.set_state;
    m_simd_size = result.simd_size;
  }

  // Returns the next cached result, refilling the cache first when the index
  // is 0. The index is 8 bits wide and wraps to 0 after the last slot.
  PRNG_ALWAYS_INLINE result_type operator()() noexcept {
    if (m_index == 0) [[unlikely]] {
      m_populate_cache(m_state.data, m_cache);
    }
    return m_cache[m_index++];
  }

  // Next result as a double in [0, 1): the top 53 bits scaled by 2^-53.
  PRNG_ALWAYS_INLINE double uniform() noexcept {
    return static_cast<double>(operator()() >> 11) * 0x1.0p-53;
  }

  // Counter derived from the raw counter and the current cache index; the
  // extra one-block step back ("prev") is requested whenever the index is
  // non-zero.
  counter_type getCounter() const noexcept {
    return m_get_counter(m_state.data, m_index != 0, m_index);
  }

  key_type getKey() const noexcept { return m_get_key(m_state.data); }

  // Raw stored counter, unadjusted; pair it with cache_index() and cache()
  // when saving the generator.
  counter_type getCounterForSerde() const noexcept {
    return m_get_raw_counter(m_state.data);
  }

  // Replaces counter and key and resets the index to 0, so the next draw
  // refills the cache.
  void setState(const counter_type &ctr, const key_type &key) noexcept {
    m_set_state(m_state.data, ctr, key);
    m_index = 0;
  }

  // Direct access to the read index and the cache contents.
  std::uint8_t cache_index() const noexcept { return m_index; }
  void set_cache_index(std::uint8_t idx) noexcept { m_index = idx; }
  const std::array<result_type, CACHE_SIZE> &cache() const noexcept { return m_cache; }
  std::array<result_type, CACHE_SIZE> &cache() noexcept { return m_cache; }

  // SIMD width reported by the implementation chosen at construction.
  PRNG_ALWAYS_INLINE std::size_t getSIMDSize() const noexcept { return m_simd_size; }

private:
  using InitResult = internal::PhiloxSIMDInitResult<N, W>;
  using populate_fn = typename InitResult::populate_fn;
  using get_counter_fn = typename InitResult::get_counter_fn;
  using get_raw_counter_fn = typename InitResult::get_raw_counter_fn;
  using get_key_fn = typename InitResult::get_key_fn;
  using set_state_fn = typename InitResult::set_state_fn;

  // operator() refills when m_index == 0 and depends on the 8-bit index
  // wrapping to 0 right after the last slot. That holds only for a cache of
  // exactly 256 entries; any other size would either read out of bounds or
  // leave part of the cache unused.
  static_assert(CACHE_SIZE == std::size_t{(std::numeric_limits<std::uint8_t>::max)()} + 1,
                "CACHE_SIZE must equal the number of values of the 8-bit cache index");

  // Raw, suitably aligned bytes in which the dispatched state is constructed.
  struct StateStorage {
    static constexpr std::size_t SIZE = 256;
    static constexpr std::size_t ALIGN = 64;
    alignas(ALIGN) unsigned char data[SIZE];
  };

  alignas(64) std::array<result_type, CACHE_SIZE> m_cache{};
  alignas(64) StateStorage m_state;
  populate_fn m_populate_cache = nullptr;
  get_counter_fn m_get_counter = nullptr;
  get_raw_counter_fn m_get_raw_counter = nullptr;
  get_key_fn m_get_key = nullptr;
  set_state_fn m_set_state = nullptr;
  std::size_t m_simd_size = 0;
  std::uint8_t m_index = 0;
};

#ifndef XSIMD_NO_SUPPORTED_ARCHITECTURE
// PhiloxNative: the same cached generator bound at compile time to
// xsimd::best_arch, holding the state directly instead of going through
// function pointers.
//
//   N - words per counter block, W - bits per word, R - rounds.
template <std::uint8_t N = 4, std::uint8_t W = 32, std::uint8_t R = 10>
class PhiloxNative {
public:
  using result_type = std::uint64_t;
  using word_type = std::conditional_t<W == 32, std::uint32_t, std::uint64_t>;
  using counter_type = std::array<word_type, N>;
  using key_type = std::array<word_type, N / 2>;
  static constexpr std::uint16_t CACHE_SIZE = 256;

  // Smallest / largest value operator() can return. The parentheses around
  // the names keep function-like min/max macros from expanding.
  static constexpr PRNG_ALWAYS_INLINE auto(min)() noexcept { return (std::numeric_limits<result_type>::min)(); }
  static constexpr PRNG_ALWAYS_INLINE auto(max)() noexcept { return (std::numeric_limits<result_type>::max)(); }

  // Converts seed and counter through Philox<N, W, R> and forwards to the
  // (key, counter) constructor.
  explicit PhiloxNative(result_type seed, result_type counter = 0) noexcept
      : PhiloxNative(Philox<N, W, R>::seed_to_key(seed), Philox<N, W, R>::counter_from_uint64(counter)) {}

  PhiloxNative(key_type key, counter_type counter) noexcept
      : m_state(key, counter) {}

  // Returns the next cached result, refilling the cache first when the index
  // is 0. The index is 8 bits wide and wraps to 0 after the last slot.
  PRNG_ALWAYS_INLINE result_type operator()() noexcept {
    if (m_index == 0) [[unlikely]] {
      m_state.populate_cache(m_cache);
    }
    return m_cache[m_index++];
  }

  // Next result as a double in [0, 1): the top 53 bits scaled by 2^-53.
  PRNG_ALWAYS_INLINE double uniform() noexcept {
    return static_cast<double>(operator()() >> 11) * 0x1.0p-53;
  }

  // Counter derived from the raw counter and the current cache index; the
  // extra one-block step back ("prev") is requested whenever the index is
  // non-zero.
  counter_type getCounter() const noexcept {
    return m_state.getCounter(m_index != 0, m_index);
  }

  key_type getKey() const noexcept { return m_state.getKey(); }

  // Raw stored counter, unadjusted; pair it with cache_index() and cache()
  // when saving the generator.
  counter_type getCounterForSerde() const noexcept {
    return m_state.getRawCounter();
  }

  // Replaces counter and key and resets the index to 0, so the next draw
  // refills the cache.
  void setState(const counter_type &ctr, const key_type &key) noexcept {
    m_state.setState(ctr, key);
    m_index = 0;
  }

  // Direct access to the read index and the cache contents.
  std::uint8_t cache_index() const noexcept { return m_index; }
  void set_cache_index(std::uint8_t idx) noexcept { m_index = idx; }
  const std::array<result_type, CACHE_SIZE> &cache() const noexcept { return m_cache; }
  std::array<result_type, CACHE_SIZE> &cache() noexcept { return m_cache; }

  // SIMD width of the compile-time selected state type.
  PRNG_ALWAYS_INLINE std::size_t getSIMDSize() const noexcept { return std::size_t{State::SIMD_WIDTH}; }

private:
  using State = internal::PhiloxState<xsimd::best_arch, N, W, R>;

  // operator() refills when m_index == 0 and depends on the 8-bit index
  // wrapping to 0 right after the last slot. That holds only for a cache of
  // exactly 256 entries; any other size would either read out of bounds or
  // leave part of the cache unused.
  static_assert(CACHE_SIZE == std::size_t{(std::numeric_limits<std::uint8_t>::max)()} + 1,
                "CACHE_SIZE must equal the number of values of the 8-bit cache index");

  // The cache is aligned for the architecture's aligned stores.
  alignas(State::simd_type::arch_type::alignment()) std::array<result_type, CACHE_SIZE> m_cache{};
  State m_state;
  std::uint8_t m_index = 0;
};
#endif // XSIMD_NO_SUPPORTED_ARCHITECTURE

// Convenience aliases
// Runtime-dispatched generators, named <N>x<W>; all use 10 rounds.
using Philox4x32SIMD = PhiloxSIMD<4, 32, 10>;
using Philox2x32SIMD = PhiloxSIMD<2, 32, 10>;
using Philox4x64SIMD = PhiloxSIMD<4, 64, 10>;
using Philox2x64SIMD = PhiloxSIMD<2, 64, 10>;

// Same parameter sets for PhiloxNative, which is only declared when
// XSIMD_NO_SUPPORTED_ARCHITECTURE is not defined.
#ifndef XSIMD_NO_SUPPORTED_ARCHITECTURE
using Philox4x32Native = PhiloxNative<4, 32, 10>;
using Philox2x32Native = PhiloxNative<2, 32, 10>;
using Philox4x64Native = PhiloxNative<4, 64, 10>;
using Philox2x64Native = PhiloxNative<2, 64, 10>;
#endif // XSIMD_NO_SUPPORTED_ARCHITECTURE

} // namespace prng