multprec.cpp
: Simplify with _Unsigned128
by StephanTLavavej · Pull Request #5473 · microsoft/STL (original) (raw)
Followup to #5436, which reimplemented linear_congruential_engine
in <random>
by using our modern _Unsigned128
. This PR shrinks our separately compiled, now-retained-for-bincompat multprec.cpp
by also using _Unsigned128
. Ordinarily we don't put a lot of effort into simplifying retained-for-bincompat code because the benefit is low and the risk is unnecessary. In this case, it's worth the risk because this fixes #1008 by removing squirrelly code that needed an /analyze
suppression. (Being /analyze
-clean is very important now, and keeping suppressions around in our codebase is a minor debt.)
My new comments do a hopefully better job than before of explaining _MP_arr
's representation, and the assumptions behind the various functions. I am aiming for simplicity here, not micro-optimizations; e.g. _MP_Rem()
is computing modulo a 64-bit value, so only the lower 64 bits will be non-zero, but I'm still going through the general assign_mp_from_u128()
helper function.
Retained-for-bincompat code isn't something that we expend effort on testing (the idea is that we don't mess with it and it keeps working). In this case, I performed manual randomized testing. I wrote a test case to generate random values, subject to the conditions in linear_congruential_engine
, and compared not only the ultimate output but also the intermediate _MP_arr
contents. I compared the new implementation (after building the STL and running set_environment.bat
, of course) to a copy of the old classic implementation, for 12.8 billion trials, and everything agreed.
Click to expand test case:
// Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include #include #include #include #include #include #include #include #include #include using namespace std; using namespace std::chrono;
namespace { using _MP_arr = uint64_t[5]; constexpr int _MP_len = 5;
// implements multiprecision math for random number generators
constexpr int shift = numeric_limits<unsigned long long>::digits / 2;
constexpr unsigned long long mask = ~(~0ULL << shift);
constexpr unsigned long long maxVal = mask + 1;
[[nodiscard]] unsigned long long __CLRCALL_PURE_OR_CDECL old_Get(
_MP_arr u) noexcept { // convert multi-word value to scalar value
return (u[1] << shift) + u[0];
}
static void add(unsigned long long* u, int ulen, unsigned long long* v,
int vlen) noexcept { // add multi-word value to multi-word value
int i;
unsigned long long k = 0;
for (i = 0; i < vlen; ++i) { // add multi-word values
u[i] += v[i] + k;
k = u[i] >> shift;
u[i] &= mask;
}
for (; k != 0 && i < ulen; ++i) { // propagate carry
u[i] += k;
k = u[i] >> shift;
u[i] &= mask;
}
}
void __CLRCALL_PURE_OR_CDECL old_Add(
_MP_arr u, unsigned long long v0) noexcept { // add scalar value to multi-word value
unsigned long long v[2];
v[0] = v0 & mask;
v[1] = v0 >> shift;
add(u, _MP_len, v, 2);
}
static void mul(unsigned long long* u, int ulen,
unsigned long long v0) noexcept { // multiply multi-word value by single-word value
unsigned long long k = 0;
for (int i = 0; i < ulen; ++i) { // multiply and propagate carry
u[i] = u[i] * v0 + k;
k = u[i] >> shift;
u[i] &= mask;
}
}
void __CLRCALL_PURE_OR_CDECL old_Mul(_MP_arr w, unsigned long long u0,
unsigned long long v0) noexcept { // multiply multi-word value by multi-word value
constexpr int m = 2;
constexpr int n = 2;
unsigned long long u[2];
unsigned long long v[2];
u[0] = u0 & mask;
u[1] = u0 >> shift;
v[0] = v0 & mask;
v[1] = v0 >> shift;
// Knuth, vol. 2, p. 268, Algorithm M
// M1: [Initialize.]
for (int i = 0; i < m + n + 1; ++i) {
w[i] = 0;
}
for (int j = 0; j < n; ++j) { // M2: [Zero multiplier?]
if (v[j] == 0) {
w[j + m] = 0;
} else { // multiply by non-zero value
unsigned long long k = 0;
int i;
// M3: [Initialize i.]
for (i = 0; i < m; ++i) { // M4: [Multiply and add.]
w[i + j] = u[i] * v[j] + w[i + j] + k;
k = w[i + j] >> shift;
w[i + j] &= mask;
// M5: [Loop on i.]
}
w[i + j] = k;
}
// M6: [Loop on j.]
}
}
static void div(_MP_arr u,
unsigned long long
v0) noexcept { // divide multi-word value by scalar value (fits in lower half of unsigned long long)
unsigned long long k = 0;
int ulen = _MP_len;
while (0 <= --ulen) { // propagate remainder and divide
unsigned long long tmp = (k << shift) + u[ulen];
u[ulen] = tmp / v0;
k = tmp % v0;
}
}
[[nodiscard]] static int limit(const unsigned long long* u, int ulen) noexcept { // get index of last non-zero value
while (u[ulen - 1] == 0) {
--ulen;
}
return ulen;
}
void __CLRCALL_PURE_OR_CDECL old_Rem(
_MP_arr u, unsigned long long v0) noexcept { // divide multi-word value by value, leaving remainder in u
unsigned long long v[2];
v[0] = v0 & mask;
v[1] = v0 >> shift;
const int n = limit(v, 2);
_Analysis_assume_(n > 0);
_Analysis_assume_(n <= 2);
const int m = limit(u, _MP_len) - n;
_Analysis_assume_(m > 0);
_Analysis_assume_(m <= _MP_len - n);
// Knuth, vol. 2, p. 272, Algorithm D
// D1: [Normalize.]
unsigned long long d = maxVal / (v[n - 1] + 1);
if (d != 1) { // scale numerator and divisor
mul(u, _MP_len, d);
mul(v, n, d);
}
// D2: [Initialize j.]
for (int j = m; 0 <= j; --j) { // D3: [Calculate qh.]
unsigned long long qh = ((u[j + n] << shift) + u[j + n - 1]) / v[n - 1];
if (qh == 0) {
continue;
}
unsigned long long rh = ((u[j + n] << shift) + u[j + n - 1]) % v[n - 1];
for (;;) {
#pragma warning(push) #pragma warning(disable : 6385) // TRANSITION, GH-1008 if (qh < maxVal && qh * v[n - 2] <= (rh << shift) + u[j + n - 2]) { #pragma warning(pop) break; } else { // reduce tentative value and retry --qh; rh += v[n - 1]; if (maxVal <= rh) { break; } } }
// D4: [Multiply and subtract.]
unsigned long long k = 0;
int i;
for (i = 0; i < n; ++i) { // multiply and subtract
u[j + i] -= qh * v[i] + k;
k = u[j + i] >> shift;
if (k) {
k = maxVal - k;
}
u[j + i] &= mask;
}
for (; k != 0 && j + i < _MP_len; ++i) { // propagate borrow
u[j + i] -= k;
k = u[j + i] >> shift;
if (k) {
k = maxVal - k;
}
u[j + i] &= mask;
}
// D5: [Test remainder.]
if (k != 0) { // D6: [Add back.]
--qh;
add(u + j, n + 1, v, n);
}
// D7: [Loop on j.]
}
// D8: [Unnormalize.]
if (d != 1) {
div(u, d);
}
}
} // unnamed namespace
_STD_BEGIN [[nodiscard]] _CRTIMP2_PURE uint64_t __CLRCALL_PURE_OR_CDECL _MP_Get(_MP_arr _Wx) noexcept; _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _MP_Add(_MP_arr _Wx, uint64_t _Cx) noexcept; _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _MP_Mul(_MP_arr _Wx, uint64_t _Prev, uint64_t _Ax) noexcept; _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _MP_Rem(_MP_arr _Wx, uint64_t _Mx) noexcept; _STD_END
int main() { atomic atom_actual_trials{0};
constexpr uint32_t num_threads = 32;
constexpr uint64_t num_trials_per_thread = 400'000'000;
vector<jthread> worker_threads;
worker_threads.reserve(num_threads);
const auto start = steady_clock::now();
for (uint32_t worker = 0; worker < num_threads; ++worker) {
worker_threads.emplace_back([worker, &atom_actual_trials] {
mt19937_64 gen{1729 + worker};
uint64_t local_actual_trials = 0;
for (uint64_t trials = 0; trials < num_trials_per_thread; ++trials) {
// `_Mx == 0` is special-cased; `_Mx == 1` would force `_Ax == 0` (see below)
uniform_int_distribution<uint64_t> mx_dist{2, UINT64_MAX};
const uint64_t mx = mx_dist(gen);
// N5008 [rand.eng.lcong]/3: "If the template parameter m is not 0,
// the following relations shall hold: a < m and c < m."
uniform_int_distribution<uint64_t> ax_dist{1, mx - 1}; // `_Ax == 0` is special-cased
const uint64_t ax = ax_dist(gen);
uniform_int_distribution<uint64_t> cx_dist{0, mx - 1};
const uint64_t cx = cx_dist(gen);
uniform_int_distribution<uint64_t> prev_dist{0, mx - 1}; // the result is always taken modulo `_Mx`
const uint64_t prev = prev_dist(gen);
if (cx <= UINT32_MAX && mx - 1 <= (UINT32_MAX - cx) / ax) {
continue;
} else if (mx - 1 <= (UINT64_MAX - cx) / ax) {
continue;
} else {
++local_actual_trials;
_MP_arr wx_old;
_MP_arr wx_new;
old_Mul(wx_old, prev, ax);
_MP_Mul(wx_new, prev, ax);
assert(memcmp(wx_old, wx_new, sizeof(_MP_arr)) == 0);
old_Add(wx_old, cx);
_MP_Add(wx_new, cx);
assert(memcmp(wx_old, wx_new, sizeof(_MP_arr)) == 0);
old_Rem(wx_old, mx);
_MP_Rem(wx_new, mx);
assert(memcmp(wx_old, wx_new, sizeof(_MP_arr)) == 0);
const auto result_old = old_Get(wx_old);
const auto result_new = _MP_Get(wx_new);
assert(result_old == result_new);
}
}
atom_actual_trials += local_actual_trials;
});
}
worker_threads.clear();
const auto finish = steady_clock::now();
println("num_threads: {}", num_threads);
println("num_trials_per_thread: {}", num_trials_per_thread);
println("atom_actual_trials: {}", atom_actual_trials.load());
println("Time: {} ms", duration_cast<milliseconds>(finish - start).count());
}
C:\Temp>cl /EHsc /nologo /W4 /std:c++latest /MT /O2 /GL meow.cpp
meow.cpp
Generating code
Finished generating code
C:\Temp>meow
num_threads: 32
num_trials_per_thread: 400000000
atom_actual_trials: 12799999995
Time: 58741 ms