multprec.cpp: Simplify with _Unsigned128 by StephanTLavavej · Pull Request #5473 · microsoft/STL (original) (raw)

Followup to #5436, which reimplemented `linear_congruential_engine` in `<random>` by using our modern `_Unsigned128`. This PR shrinks our separately compiled, now-retained-for-bincompat `multprec.cpp` by also using `_Unsigned128`. Ordinarily we don't put a lot of effort into simplifying retained-for-bincompat code because the benefit is low and the risk is unnecessary. In this case, it's worth the risk because this fixes #1008 by removing squirrelly code that needed an `/analyze` suppression. (Being `/analyze`-clean is very important now, and keeping suppressions around in our codebase is a minor debt.)

My new comments do a hopefully better job than before of explaining `_MP_arr`'s representation, and the assumptions behind the various functions. I am aiming for simplicity here, not micro-optimizations; e.g. `_MP_Rem()` is computing modulo a 64-bit value, so only the lower 64 bits will be non-zero, but I'm still going through the general `assign_mp_from_u128()` helper function.

Retained-for-bincompat code isn't something that we expend effort on testing (the idea is that we don't mess with it and it keeps working). In this case, I performed manual randomized testing. I wrote a test case to generate random values, subject to the conditions in `linear_congruential_engine`, and compared not only the ultimate output but also the intermediate `_MP_arr` contents. I compared the new implementation (after building the STL and running `set_environment.bat`, of course) to a copy of the old classic implementation, for 12.8 billion trials, and everything agreed.

Click to expand test case:

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <atomic>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <limits>
#include <print>
#include <random>
#include <thread>
#include <vector>
using namespace std;
using namespace std::chrono;

// File-local declarations (this unnamed namespace is closed further down, after old_Rem):
// _MP_arr is an array of five uint64_t elements and _MP_len is that element count.
namespace { using _MP_arr = uint64_t[5]; constexpr int _MP_len = 5;

// implements multiprecision math for random number generators

// Each multi-word digit occupies half the bits of an unsigned long long.
constexpr int shift = std::numeric_limits<unsigned long long>::digits / 2;
// Exclusive upper bound of a single digit, i.e. 2^shift.
constexpr unsigned long long maxVal = 1ULL << shift;
// All-ones pattern covering exactly one digit (the low `shift` bits).
constexpr unsigned long long mask = maxVal - 1;

// Reference ("old") conversion of a multi-word value back to a scalar.
// Only the two lowest digits, u[0] and u[1], are read; u[1] supplies the upper half of the result.
[[nodiscard]] unsigned long long __CLRCALL_PURE_OR_CDECL old_Get(_MP_arr u) noexcept {
    const unsigned long long low_half  = u[0];
    const unsigned long long high_half = u[1] << shift;
    return high_half + low_half;
}

// Adds the vlen-digit value v into the ulen-digit value u, in place.
// Every digit is kept within `mask`; whatever overflows a digit is carried into the next one.
// A carry still pending after digit ulen - 1 is dropped.
static void add(unsigned long long* u, int ulen, unsigned long long* v, int vlen) noexcept {
    unsigned long long carry = 0;
    int pos                  = 0;

    // Digit-by-digit sum over the digits that v provides.
    while (pos < vlen) {
        const unsigned long long sum = u[pos] + (v[pos] + carry);
        carry                        = sum >> shift;
        u[pos]                       = sum & mask;
        ++pos;
    }

    // Ripple any remaining carry through the higher digits of u.
    while (carry != 0 && pos < ulen) {
        const unsigned long long sum = u[pos] + carry;
        carry                        = sum >> shift;
        u[pos]                       = sum & mask;
        ++pos;
    }
}

// Reference ("old") addition of a scalar to a multi-word value, in place.
// v0 is split into its low and high digits and handed to add() together with all _MP_len digits of u.
void __CLRCALL_PURE_OR_CDECL old_Add(_MP_arr u, unsigned long long v0) noexcept {
    unsigned long long addend[2] = {v0 & mask, v0 >> shift};
    add(u, _MP_len, addend, 2);
}

// Multiplies the ulen-digit value u by the single value v0, in place.
// Each product is reduced to one digit and the excess is carried into the next digit;
// the carry left over after the last digit is dropped.
static void mul(unsigned long long* u, int ulen, unsigned long long v0) noexcept {
    unsigned long long carry = 0;
    int pos                  = 0;
    while (pos < ulen) {
        const unsigned long long product = u[pos] * v0 + carry;
        carry                            = product >> shift;
        u[pos]                           = product & mask;
        ++pos;
    }
}

// Reference ("old") multiplication: writes the product u0 * v0 into w as a multi-word value.
// Both operands are split into two digits and multiplied schoolbook-style
// (Knuth, vol. 2, p. 268, Algorithm M). Digits w[0] through w[4] are all written.
void __CLRCALL_PURE_OR_CDECL old_Mul(_MP_arr w, unsigned long long u0, unsigned long long v0) noexcept {
    constexpr int lhs_digits = 2;
    constexpr int rhs_digits = 2;

    const unsigned long long lhs[lhs_digits] = {u0 & mask, u0 >> shift};
    const unsigned long long rhs[rhs_digits] = {v0 & mask, v0 >> shift};

    // M1: [Initialize.] Clear every digit of the result, including the one beyond the product width.
    for (int pos = 0; pos <= lhs_digits + rhs_digits; ++pos) {
        w[pos] = 0;
    }

    for (int j = 0; j != rhs_digits; ++j) {
        const unsigned long long multiplier = rhs[j];

        // M2: [Zero multiplier?] Nothing to accumulate for this row; just clear its top digit.
        if (multiplier == 0) {
            w[j + lhs_digits] = 0;
            continue; // M6: [Loop on j.]
        }

        // M3/M4/M5: multiply one row and accumulate it into w, carrying between digits.
        unsigned long long carry = 0;
        for (int i = 0; i != lhs_digits; ++i) {
            const unsigned long long total = lhs[i] * multiplier + w[i + j] + carry;
            carry                          = total >> shift;
            w[i + j]                       = total & mask;
        }

        // The row's final carry becomes the digit just above the row.
        w[j + lhs_digits] = carry;
        // M6: [Loop on j.]
    }
}

// Divides the multi-word value u by the scalar v0, in place, leaving the quotient in u.
// Works from the most significant digit down, feeding each step's remainder into the next digit.
// v0 is expected to fit in the lower half of an unsigned long long.
static void div(_MP_arr u, unsigned long long v0) noexcept {
    unsigned long long remainder = 0;
    for (int pos = _MP_len - 1; pos >= 0; --pos) {
        const unsigned long long dividend = (remainder << shift) + u[pos];
        u[pos]                            = dividend / v0;
        remainder                         = dividend % v0;
    }
}

// Returns the number of significant digits in u[0, ulen): one past the index of the highest
// non-zero element. Returns 0 when every element is zero or when ulen is not positive.
[[nodiscard]] static int limit(const unsigned long long* u, int ulen) noexcept {
    // The `0 < ulen` guard stops the scan at the start of the array; without it, an all-zero
    // input would keep decrementing and read u[-1] and below.
    while (0 < ulen && u[ulen - 1] == 0) {
        --ulen;
    }

    return ulen;
}

// Reference ("old") remainder: reduces the multi-word value u modulo v0, in place, by long division on
// half-word digits (Knuth, vol. 2, p. 272, Algorithm D). The divisor is split into two digits, both
// operands are scaled by d so the divisor's top digit is large, the quotient digits are estimated and
// corrected one at a time, and finally u is divided by d again to undo the scaling.
// The quotient itself is never stored; only the remainder left in u matters.
void __CLRCALL_PURE_OR_CDECL old_Rem(
    _MP_arr u, unsigned long long v0) noexcept { // divide multi-word value by value, leaving remainder in u
    unsigned long long v[2];
    v[0]        = v0 & mask;
    v[1]        = v0 >> shift;
    const int n = limit(v, 2); // number of significant divisor digits
    _Analysis_assume_(n > 0);
    _Analysis_assume_(n <= 2);
    const int m = limit(u, _MP_len) - n; // how many more digits the numerator has than the divisor
    _Analysis_assume_(m > 0);
    _Analysis_assume_(m <= _MP_len - n);

    // Knuth, vol. 2, p. 272, Algorithm D
    // D1: [Normalize.]
    unsigned long long d = maxVal / (v[n - 1] + 1);
    if (d != 1) { // scale numerator and divisor
        mul(u, _MP_len, d);
        mul(v, n, d);
    }
    // D2: [Initialize j.]
    for (int j = m; 0 <= j; --j) { // D3: [Calculate qh.]
        // Estimate the quotient digit from the top two numerator digits and the top divisor digit.
        unsigned long long qh = ((u[j + n] << shift) + u[j + n - 1]) / v[n - 1];
        if (qh == 0) {
            continue;
        }

        unsigned long long rh = ((u[j + n] << shift) + u[j + n - 1]) % v[n - 1];
        for (;;) {
            // NOTE(review): when n == 1 the index n - 2 is -1; presumably that is what the C6385
            // suppression below covers - confirm against GH-1008.
#pragma warning(push)
#pragma warning(disable : 6385) // TRANSITION, GH-1008
            if (qh < maxVal && qh * v[n - 2] <= (rh << shift) + u[j + n - 2]) {
#pragma warning(pop)
                break;
            } else { // reduce tentative value and retry
                --qh;
                rh += v[n - 1];
                if (maxVal <= rh) {
                    break;
                }
            }
        }

        // D4: [Multiply and subtract.]
        unsigned long long k = 0;
        int i;
        for (i = 0; i < n; ++i) { // multiply and subtract
            u[j + i] -= qh * v[i] + k;
            k = u[j + i] >> shift;
            if (k) {
                // The subtraction wrapped; turn the wrapped upper half into the borrow for the next digit.
                k = maxVal - k;
            }

            u[j + i] &= mask;
        }
        for (; k != 0 && j + i < _MP_len; ++i) { // propagate borrow
            u[j + i] -= k;
            k = u[j + i] >> shift;
            if (k) {
                k = maxVal - k;
            }

            u[j + i] &= mask;
        }
        // D5: [Test remainder.]
        if (k != 0) { // D6: [Add back.]
            // The estimate was one too large: add the divisor back once.
            // qh is not read after this point; the decrement only mirrors the textbook step.
            --qh;
            add(u + j, n + 1, v, n);
        }
        // D7: [Loop on j.]
    }
    // D8: [Unnormalize.]
    if (d != 1) {
        div(u, d);
    }
}

} // unnamed namespace

// Declarations only: these four functions are defined outside this file. They are the counterparts of
// old_Get / old_Add / old_Mul / old_Rem that main() calls with the same arguments for comparison.
_STD_BEGIN
[[nodiscard]] _CRTIMP2_PURE uint64_t __CLRCALL_PURE_OR_CDECL _MP_Get(_MP_arr _Wx) noexcept;
_CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _MP_Add(_MP_arr _Wx, uint64_t _Cx) noexcept;
_CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _MP_Mul(_MP_arr _Wx, uint64_t _Prev, uint64_t _Ax) noexcept;
_CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _MP_Rem(_MP_arr _Wx, uint64_t _Mx) noexcept;
_STD_END

// Randomized differential test: on 32 worker threads, draws random (mx, ax, cx, prev) tuples, runs the
// old_* reference functions and the _MP_* functions side by side, and asserts after every step that the
// two multi-word arrays are byte-identical and that the final scalars agree. Prints the thread count,
// the per-thread trial count, the number of trials that were actually compared, and the elapsed time.
int main() {
    // The element type must be spelled out: a bare `atomic` initialized with `{0}` would deduce
    // atomic<int>, which cannot hold the sum of the per-thread uint64_t counts.
    atomic<uint64_t> atom_actual_trials{0};

    constexpr uint32_t num_threads           = 32;
    constexpr uint64_t num_trials_per_thread = 400'000'000;

    vector<jthread> worker_threads;
    worker_threads.reserve(num_threads);

    const auto start = steady_clock::now();
    for (uint32_t worker = 0; worker < num_threads; ++worker) {
        worker_threads.emplace_back([worker, &atom_actual_trials] {
            // Each worker gets its own generator with a distinct, fixed seed.
            mt19937_64 gen{1729 + worker};

            uint64_t local_actual_trials = 0;

            for (uint64_t trials = 0; trials < num_trials_per_thread; ++trials) {
                // `_Mx == 0` is special-cased; `_Mx == 1` would force `_Ax == 0` (see below)
                uniform_int_distribution<uint64_t> mx_dist{2, UINT64_MAX};
                const uint64_t mx = mx_dist(gen);

                // N5008 [rand.eng.lcong]/3: "If the template parameter m is not 0,
                // the following relations shall hold: a < m and c < m."

                uniform_int_distribution<uint64_t> ax_dist{1, mx - 1}; // `_Ax == 0` is special-cased
                const uint64_t ax = ax_dist(gen);

                uniform_int_distribution<uint64_t> cx_dist{0, mx - 1};
                const uint64_t cx = cx_dist(gen);

                uniform_int_distribution<uint64_t> prev_dist{0, mx - 1}; // the result is always taken modulo `_Mx`
                const uint64_t prev = prev_dist(gen);

                // Skip tuples for which (mx - 1) * ax + cx fits in 32 bits or in 64 bits.
                // NOTE(review): presumably the engine never takes the multi-word path for these - confirm.
                if (cx <= UINT32_MAX && mx - 1 <= (UINT32_MAX - cx) / ax) {
                    continue;
                } else if (mx - 1 <= (UINT64_MAX - cx) / ax) {
                    continue;
                } else {
                    ++local_actual_trials;

                    _MP_arr wx_old;
                    _MP_arr wx_new;

                    // Compare the intermediate arrays after every step, not just the final result.
                    old_Mul(wx_old, prev, ax);
                    _MP_Mul(wx_new, prev, ax);
                    assert(memcmp(wx_old, wx_new, sizeof(_MP_arr)) == 0);

                    old_Add(wx_old, cx);
                    _MP_Add(wx_new, cx);
                    assert(memcmp(wx_old, wx_new, sizeof(_MP_arr)) == 0);

                    old_Rem(wx_old, mx);
                    _MP_Rem(wx_new, mx);
                    assert(memcmp(wx_old, wx_new, sizeof(_MP_arr)) == 0);

                    const auto result_old = old_Get(wx_old);
                    const auto result_new = _MP_Get(wx_new);
                    assert(result_old == result_new);
                }
            }

            // Publish once per thread rather than once per trial.
            atom_actual_trials += local_actual_trials;
        });
    }
    // Destroying the jthreads joins them, so `finish` is taken only after every worker is done.
    worker_threads.clear();
    const auto finish = steady_clock::now();

    println("num_threads: {}", num_threads);
    println("num_trials_per_thread: {}", num_trials_per_thread);
    println("atom_actual_trials: {}", atom_actual_trials.load());
    println("Time: {} ms", duration_cast<milliseconds>(finish - start).count());
}

C:\Temp>cl /EHsc /nologo /W4 /std:c++latest /MT /O2 /GL meow.cpp
meow.cpp
Generating code
Finished generating code

C:\Temp>meow
num_threads: 32
num_trials_per_thread: 400000000
atom_actual_trials: 12799999995
Time: 58741 ms