Improve isqrt
tests and add benchmarks · patricklam/verify-rust-std@2139651 (original) (raw)
``
1
`+
macro_rules! tests {
`
``
2
`+
($isqrt_consistency_check_fn_macro:ident : (((T:ident)+) => {
`
``
3
`+
$(
`
``
4
`+
mod $T {
`
``
5
`+
isqrtconsistencycheckfnmacro!(isqrt_consistency_check_fn_macro!(isqrtconsistencycheckfnmacro!(T);
`
``
6
+
``
7
`+
// Check that the following produce the correct values from
`
``
8
`` +
// isqrt
:
``
``
9
`+
//
`
``
10
`+
// * the first and last 128 nonnegative values
`
``
11
`+
// * powers of two, minus one
`
``
12
`+
// * powers of two
`
``
13
`+
//
`
``
14
`` +
// For signed types, check that checked_isqrt
and isqrt
``
``
15
`+
// either produce the same numeric value or respectively
`
``
16
`` +
// produce None
and a panic. Make sure to do a consistency
``
``
17
`` +
// check for <$T>::MIN
as well, as no nonnegative values
``
``
18
`+
// negate to it.
`
``
19
`+
//
`
``
20
`` +
// For unsigned types check that isqrt
produces the same
``
``
21
`` +
// numeric value for $T
and NonZero<$T>
.
``
``
22
`+
#[test]
`
``
23
`+
fn isqrt() {
`
``
24
`+
isqrt_consistency_check(<$T>::MIN);
`
``
25
+
``
26
`+
for n in (0..=127)
`
``
27
`+
.chain(<$T>::MAX - 127..=<$T>::MAX)
`
``
28
`+
.chain((0..<$T>::MAX.count_ones()).map(|exponent| (1 << exponent) - 1))
`
``
29
`+
.chain((0..<$T>::MAX.count_ones()).map(|exponent| 1 << exponent))
`
``
30
`+
{
`
``
31
`+
isqrt_consistency_check(n);
`
``
32
+
``
33
`+
let isqrt_n = n.isqrt();
`
``
34
`+
assert!(
`
``
35
`+
isqrt_n
`
``
36
`+
.checked_mul(isqrt_n)
`
``
37
`+
.map(|isqrt_n_squared| isqrt_n_squared <= n)
`
``
38
`+
.unwrap_or(false),
`
``
39
`` +
"{n}.isqrt()
should be lower than {isqrt_n}."
``
``
40
`+
);
`
``
41
`+
assert!(
`
``
42
`+
(isqrt_n + 1)
`
``
43
`+
.checked_mul(isqrt_n + 1)
`
``
44
`+
.map(|isqrt_n_plus_1_squared| n < isqrt_n_plus_1_squared)
`
``
45
`+
.unwrap_or(true),
`
``
46
`` +
"{n}.isqrt()
should be higher than {isqrt_n})."
``
``
47
`+
);
`
``
48
`+
}
`
``
49
`+
}
`
``
50
+
``
51
`+
// Check the square roots of:
`
``
52
`+
//
`
``
53
`+
// * the first 1,024 perfect squares
`
``
54
`+
// * halfway between each of the first 1,024 perfect squares
`
``
55
`+
// and the next perfect square
`
``
56
`+
// * the next perfect square after the each of the first 1,024
`
``
57
`+
// perfect squares, minus one
`
``
58
`+
// * the last 1,024 perfect squares
`
``
59
`+
// * the last 1,024 perfect squares, minus one
`
``
60
`+
// * halfway between each of the last 1,024 perfect squares
`
``
61
`+
// and the previous perfect square
`
``
62
`+
#[test]
`
``
63
`+
// Skip this test on Miri, as it takes too long to run.
`
``
64
`+
#[cfg(not(miri))]
`
``
65
`+
fn isqrt_extended() {
`
``
66
`+
// The correct value is worked out by using the fact that
`
``
67
`+
// the nth nonzero perfect square is the sum of the first n
`
``
68
`+
// odd numbers:
`
``
69
`+
//
`
``
70
`+
// 1 = 1
`
``
71
`+
// 4 = 1 + 3
`
``
72
`+
// 9 = 1 + 3 + 5
`
``
73
`+
// 16 = 1 + 3 + 5 + 7
`
``
74
`+
//
`
``
75
`+
// Note also that the last odd number added in is two times
`
``
76
`+
// the square root of the previous perfect square, plus
`
``
77
`+
// one:
`
``
78
`+
//
`
``
79
`+
// 1 = 2*0 + 1
`
``
80
`+
// 3 = 2*1 + 1
`
``
81
`+
// 5 = 2*2 + 1
`
``
82
`+
// 7 = 2*3 + 1
`
``
83
`+
//
`
``
84
`+
// That means we can add the square root of this perfect
`
``
85
`+
// square once to get about halfway to the next perfect
`
``
86
`+
// square, then we can add the square root of this perfect
`
``
87
`+
// square again to get to the next perfect square, minus
`
``
88
`+
// one, then we can add one to get to the next perfect
`
``
89
`+
// square.
`
``
90
`+
//
`
``
91
`+
// This allows us to, for each of the first 1,024 perfect
`
``
92
`+
// squares, test that the square roots of the following are
`
``
93
`+
// all correct and equal to each other:
`
``
94
`+
//
`
``
95
`+
// * the current perfect square
`
``
96
`+
// * about halfway to the next perfect square
`
``
97
`+
// * the next perfect square, minus one
`
``
98
`+
let mut n: $T = 0;
`
``
99
`+
for sqrt_n in 0..1_024.min((1_u128 << (<$T>::MAX.count_ones()/2)) - 1) as $T {
`
``
100
`+
isqrt_consistency_check(n);
`
``
101
`+
assert_eq!(
`
``
102
`+
n.isqrt(),
`
``
103
`+
sqrt_n,
`
``
104
`` +
"{sqrt_n}.pow(2).isqrt()
should be {sqrt_n}."
``
``
105
`+
);
`
``
106
+
``
107
`+
n += sqrt_n;
`
``
108
`+
isqrt_consistency_check(n);
`
``
109
`+
assert_eq!(
`
``
110
`+
n.isqrt(),
`
``
111
`+
sqrt_n,
`
``
112
`` +
"{n} is about halfway between {sqrt_n}.pow(2)
and {}.pow(2)
, so {n}.isqrt()
should be {sqrt_n}.",
``
``
113
`+
sqrt_n + 1
`
``
114
`+
);
`
``
115
+
``
116
`+
n += sqrt_n;
`
``
117
`+
isqrt_consistency_check(n);
`
``
118
`+
assert_eq!(
`
``
119
`+
n.isqrt(),
`
``
120
`+
sqrt_n,
`
``
121
`` +
"({}.pow(2) - 1).isqrt()
should be {sqrt_n}.",
``
``
122
`+
sqrt_n + 1
`
``
123
`+
);
`
``
124
+
``
125
`+
n += 1;
`
``
126
`+
}
`
``
127
+
``
128
`+
// Similarly, for each of the last 1,024 perfect squares,
`
``
129
`+
// check:
`
``
130
`+
//
`
``
131
`+
// * the current perfect square
`
``
132
`+
// * the current perfect square, minus one
`
``
133
`+
// * about halfway to the previous perfect square
`
``
134
`+
//
`
``
135
`` +
// MAX
's isqrt
return value is verified in the isqrt
``
``
136
`+
// test function above.
`
``
137
`+
let maximum_sqrt = <$T>::MAX.isqrt();
`
``
138
`+
let mut n = maximum_sqrt * maximum_sqrt;
`
``
139
+
``
140
`+
for sqrt_n in (maximum_sqrt - 1_024.min((1_u128 << (<$T>::MAX.count_ones()/2)) - 1) as $T..maximum_sqrt).rev() {
`
``
141
`+
isqrt_consistency_check(n);
`
``
142
`+
assert_eq!(
`
``
143
`+
n.isqrt(),
`
``
144
`+
sqrt_n + 1,
`
``
145
`` +
"{0}.pow(2).isqrt()
should be {0}.",
``
``
146
`+
sqrt_n + 1
`
``
147
`+
);
`
``
148
+
``
149
`+
n -= 1;
`
``
150
`+
isqrt_consistency_check(n);
`
``
151
`+
assert_eq!(
`
``
152
`+
n.isqrt(),
`
``
153
`+
sqrt_n,
`
``
154
`` +
"({}.pow(2) - 1).isqrt()
should be {sqrt_n}.",
``
``
155
`+
sqrt_n + 1
`
``
156
`+
);
`
``
157
+
``
158
`+
n -= sqrt_n;
`
``
159
`+
isqrt_consistency_check(n);
`
``
160
`+
assert_eq!(
`
``
161
`+
n.isqrt(),
`
``
162
`+
sqrt_n,
`
``
163
`` +
"{n} is about halfway between {sqrt_n}.pow(2)
and {}.pow(2)
, so {n}.isqrt()
should be {sqrt_n}.",
``
``
164
`+
sqrt_n + 1
`
``
165
`+
);
`
``
166
+
``
167
`+
n -= sqrt_n;
`
``
168
`+
}
`
``
169
`+
}
`
``
170
`+
}
`
``
171
`+
)*
`
``
172
`+
};
`
``
173
`+
}
`
``
174
+
``
175
`+
macro_rules! signed_check {
`
``
176
`+
($T:ident) => {
`
``
177
`+
/// This takes an input and, if it's nonnegative or
`
``
178
`` +
#[doc = concat!("", stringify!($T), "::MIN
,")]
``
``
179
`` +
/// checks that isqrt
and checked_isqrt
produce equivalent results
``
``
180
`+
/// for that input and for the negative of that input.
`
``
181
`+
///
`
``
182
`+
/// # Note
`
``
183
`+
///
`
``
184
`` +
/// This cannot check that negative inputs to isqrt
cause panics if
``
``
185
`+
/// panics abort instead of unwind.
`
``
186
`+
fn isqrt_consistency_check(n: $T) {
`
``
187
`` +
// <$T>::MIN
will be negative, so ignore it in this nonnegative
``
``
188
`+
// section.
`
``
189
`+
if n >= 0 {
`
``
190
`+
assert_eq!(
`
``
191
`+
Some(n.isqrt()),
`
``
192
`+
n.checked_isqrt(),
`
``
193
`` +
"{n}.checked_isqrt()
should match Some({n}.isqrt())
.",
``
``
194
`+
);
`
``
195
`+
}
`
``
196
+
``
197
`` +
// wrapping_neg
so that <$T>::MIN
will negate to itself rather
``
``
198
`+
// than panicking.
`
``
199
`+
let negative_n = n.wrapping_neg();
`
``
200
+
``
201
`+
// Zero negated will still be nonnegative, so ignore it in this
`
``
202
`+
// negative section.
`
``
203
`+
if negative_n < 0 {
`
``
204
`+
assert_eq!(
`
``
205
`+
negative_n.checked_isqrt(),
`
``
206
`+
None,
`
``
207
`` +
"({negative_n}).checked_isqrt()
should be None
, as {negative_n} is negative.",
``
``
208
`+
);
`
``
209
+
``
210
`` +
// catch_unwind
only works when panics unwind rather than abort.
``
``
211
`+
#[cfg(panic = "unwind")]
`
``
212
`+
{
`
``
213
`+
std::panic::catch_unwind(core::panic::AssertUnwindSafe(|| (-n).isqrt())).expect_err(
`
``
214
`` +
&format!("({negative_n}).isqrt()
should have panicked, as {negative_n} is negative.")
``
``
215
`+
);
`
``
216
`+
}
`
``
217
`+
}
`
``
218
`+
}
`
``
219
`+
};
`
``
220
`+
}
`
``
221
+
``
222
`+
macro_rules! unsigned_check {
`
``
223
`+
($T:ident) => {
`
``
224
`` +
/// This takes an input and, if it's nonzero, checks that isqrt
``
``
225
`+
/// produces the same numeric value for both
`
``
226
`` +
#[doc = concat!("", stringify!($T), "
and ")]
``
``
227
`` +
#[doc = concat!("NonZero<", stringify!($T), ">
.")]
``
``
228
`+
fn isqrt_consistency_check(n: $T) {
`
``
229
`` +
// Zero cannot be turned into a NonZero
value, so ignore it in
``
``
230
`+
// this nonzero section.
`
``
231
`+
if n > 0 {
`
``
232
`+
assert_eq!(
`
``
233
`+
n.isqrt(),
`
``
234
`+
core::num::NonZero::<$T>::new(n)
`
``
235
`+
.expect(
`
``
236
`` +
"Was not able to create a new NonZero
value from a nonzero number."
``
``
237
`+
)
`
``
238
`+
.isqrt()
`
``
239
`+
.get(),
`
``
240
`` +
"{n}.isqrt
should match NonZero
's {n}.isqrt().get()
.",
``
``
241
`+
);
`
``
242
`+
}
`
``
243
`+
}
`
``
244
`+
};
`
``
245
`+
}
`
``
246
+
``
247
`+
tests!(signed_check: i8 i16 i32 i64 i128);
`
``
248
`+
tests!(unsigned_check: u8 u16 u32 u64 u128);
`