Improve isqrt tests and add benchmarks · patricklam/verify-rust-std@2139651 (original) (raw)

``

1

`+

macro_rules! tests {

`

``

2

`+

($isqrt_consistency_check_fn_macro:ident : (((T:ident)+) => {

`

``

3

`+

$(

`

``

4

`+

mod $T {

`

``

5

`+

isqrtconsistencycheckfnmacro!(isqrt_consistency_check_fn_macro!(isqrtconsistencycheckfnmacro!(T);

`

``

6

+

``

7

`+

// Check that the following produce the correct values from

`

``

8

`` +

// isqrt:

``

``

9

`+

//

`

``

10

`+

// * the first and last 128 nonnegative values

`

``

11

`+

// * powers of two, minus one

`

``

12

`+

// * powers of two

`

``

13

`+

//

`

``

14

`` +

// For signed types, check that checked_isqrt and isqrt

``

``

15

`+

// either produce the same numeric value or respectively

`

``

16

`` +

// produce None and a panic. Make sure to do a consistency

``

``

17

`` +

// check for <$T>::MIN as well, as no nonnegative values

``

``

18

`+

// negate to it.

`

``

19

`+

//

`

``

20

`` +

// For unsigned types check that isqrt produces the same

``

``

21

`` +

// numeric value for $T and NonZero<$T>.

``

``

22

`+

#[test]

`

``

23

`+

fn isqrt() {

`

``

24

`+

isqrt_consistency_check(<$T>::MIN);

`

``

25

+

``

26

`+

for n in (0..=127)

`

``

27

`+

.chain(<$T>::MAX - 127..=<$T>::MAX)

`

``

28

`+

.chain((0..<$T>::MAX.count_ones()).map(|exponent| (1 << exponent) - 1))

`

``

29

`+

.chain((0..<$T>::MAX.count_ones()).map(|exponent| 1 << exponent))

`

``

30

`+

{

`

``

31

`+

isqrt_consistency_check(n);

`

``

32

+

``

33

`+

let isqrt_n = n.isqrt();

`

``

34

`+

assert!(

`

``

35

`+

isqrt_n

`

``

36

`+

.checked_mul(isqrt_n)

`

``

37

`+

.map(|isqrt_n_squared| isqrt_n_squared <= n)

`

``

38

`+

.unwrap_or(false),

`

``

39

`` +

"{n}.isqrt() should be lower than {isqrt_n}."

``

``

40

`+

);

`

``

41

`+

assert!(

`

``

42

`+

(isqrt_n + 1)

`

``

43

`+

.checked_mul(isqrt_n + 1)

`

``

44

`+

.map(|isqrt_n_plus_1_squared| n < isqrt_n_plus_1_squared)

`

``

45

`+

.unwrap_or(true),

`

``

46

`` +

"{n}.isqrt() should be higher than {isqrt_n})."

``

``

47

`+

);

`

``

48

`+

}

`

``

49

`+

}

`

``

50

+

``

51

`+

// Check the square roots of:

`

``

52

`+

//

`

``

53

`+

// * the first 1,024 perfect squares

`

``

54

`+

// * halfway between each of the first 1,024 perfect squares

`

``

55

`+

// and the next perfect square

`

``

56

`+

// * the next perfect square after the each of the first 1,024

`

``

57

`+

// perfect squares, minus one

`

``

58

`+

// * the last 1,024 perfect squares

`

``

59

`+

// * the last 1,024 perfect squares, minus one

`

``

60

`+

// * halfway between each of the last 1,024 perfect squares

`

``

61

`+

// and the previous perfect square

`

``

62

`+

#[test]

`

``

63

`+

// Skip this test on Miri, as it takes too long to run.

`

``

64

`+

#[cfg(not(miri))]

`

``

65

`+

fn isqrt_extended() {

`

``

66

`+

// The correct value is worked out by using the fact that

`

``

67

`+

// the nth nonzero perfect square is the sum of the first n

`

``

68

`+

// odd numbers:

`

``

69

`+

//

`

``

70

`+

// 1 = 1

`

``

71

`+

// 4 = 1 + 3

`

``

72

`+

// 9 = 1 + 3 + 5

`

``

73

`+

// 16 = 1 + 3 + 5 + 7

`

``

74

`+

//

`

``

75

`+

// Note also that the last odd number added in is two times

`

``

76

`+

// the square root of the previous perfect square, plus

`

``

77

`+

// one:

`

``

78

`+

//

`

``

79

`+

// 1 = 2*0 + 1

`

``

80

`+

// 3 = 2*1 + 1

`

``

81

`+

// 5 = 2*2 + 1

`

``

82

`+

// 7 = 2*3 + 1

`

``

83

`+

//

`

``

84

`+

// That means we can add the square root of this perfect

`

``

85

`+

// square once to get about halfway to the next perfect

`

``

86

`+

// square, then we can add the square root of this perfect

`

``

87

`+

// square again to get to the next perfect square, minus

`

``

88

`+

// one, then we can add one to get to the next perfect

`

``

89

`+

// square.

`

``

90

`+

//

`

``

91

`+

// This allows us to, for each of the first 1,024 perfect

`

``

92

`+

// squares, test that the square roots of the following are

`

``

93

`+

// all correct and equal to each other:

`

``

94

`+

//

`

``

95

`+

// * the current perfect square

`

``

96

`+

// * about halfway to the next perfect square

`

``

97

`+

// * the next perfect square, minus one

`

``

98

`+

let mut n: $T = 0;

`

``

99

`+

for sqrt_n in 0..1_024.min((1_u128 << (<$T>::MAX.count_ones()/2)) - 1) as $T {

`

``

100

`+

isqrt_consistency_check(n);

`

``

101

`+

assert_eq!(

`

``

102

`+

n.isqrt(),

`

``

103

`+

sqrt_n,

`

``

104

`` +

"{sqrt_n}.pow(2).isqrt() should be {sqrt_n}."

``

``

105

`+

);

`

``

106

+

``

107

`+

n += sqrt_n;

`

``

108

`+

isqrt_consistency_check(n);

`

``

109

`+

assert_eq!(

`

``

110

`+

n.isqrt(),

`

``

111

`+

sqrt_n,

`

``

112

`` +

"{n} is about halfway between {sqrt_n}.pow(2) and {}.pow(2), so {n}.isqrt() should be {sqrt_n}.",

``

``

113

`+

sqrt_n + 1

`

``

114

`+

);

`

``

115

+

``

116

`+

n += sqrt_n;

`

``

117

`+

isqrt_consistency_check(n);

`

``

118

`+

assert_eq!(

`

``

119

`+

n.isqrt(),

`

``

120

`+

sqrt_n,

`

``

121

`` +

"({}.pow(2) - 1).isqrt() should be {sqrt_n}.",

``

``

122

`+

sqrt_n + 1

`

``

123

`+

);

`

``

124

+

``

125

`+

n += 1;

`

``

126

`+

}

`

``

127

+

``

128

`+

// Similarly, for each of the last 1,024 perfect squares,

`

``

129

`+

// check:

`

``

130

`+

//

`

``

131

`+

// * the current perfect square

`

``

132

`+

// * the current perfect square, minus one

`

``

133

`+

// * about halfway to the previous perfect square

`

``

134

`+

//

`

``

135

`` +

// MAX's isqrt return value is verified in the isqrt

``

``

136

`+

// test function above.

`

``

137

`+

let maximum_sqrt = <$T>::MAX.isqrt();

`

``

138

`+

let mut n = maximum_sqrt * maximum_sqrt;

`

``

139

+

``

140

`+

for sqrt_n in (maximum_sqrt - 1_024.min((1_u128 << (<$T>::MAX.count_ones()/2)) - 1) as $T..maximum_sqrt).rev() {

`

``

141

`+

isqrt_consistency_check(n);

`

``

142

`+

assert_eq!(

`

``

143

`+

n.isqrt(),

`

``

144

`+

sqrt_n + 1,

`

``

145

`` +

"{0}.pow(2).isqrt() should be {0}.",

``

``

146

`+

sqrt_n + 1

`

``

147

`+

);

`

``

148

+

``

149

`+

n -= 1;

`

``

150

`+

isqrt_consistency_check(n);

`

``

151

`+

assert_eq!(

`

``

152

`+

n.isqrt(),

`

``

153

`+

sqrt_n,

`

``

154

`` +

"({}.pow(2) - 1).isqrt() should be {sqrt_n}.",

``

``

155

`+

sqrt_n + 1

`

``

156

`+

);

`

``

157

+

``

158

`+

n -= sqrt_n;

`

``

159

`+

isqrt_consistency_check(n);

`

``

160

`+

assert_eq!(

`

``

161

`+

n.isqrt(),

`

``

162

`+

sqrt_n,

`

``

163

`` +

"{n} is about halfway between {sqrt_n}.pow(2) and {}.pow(2), so {n}.isqrt() should be {sqrt_n}.",

``

``

164

`+

sqrt_n + 1

`

``

165

`+

);

`

``

166

+

``

167

`+

n -= sqrt_n;

`

``

168

`+

}

`

``

169

`+

}

`

``

170

`+

}

`

``

171

`+

)*

`

``

172

`+

};

`

``

173

`+

}

`

``

174

+

``

175

`+

macro_rules! signed_check {

`

``

176

`+

($T:ident) => {

`

``

177

`+

/// This takes an input and, if it's nonnegative or

`

``

178

`` +

#[doc = concat!("", stringify!($T), "::MIN,")]

``

``

179

`` +

/// checks that isqrt and checked_isqrt produce equivalent results

``

``

180

`+

/// for that input and for the negative of that input.

`

``

181

`+

///

`

``

182

`+

/// # Note

`

``

183

`+

///

`

``

184

`` +

/// This cannot check that negative inputs to isqrt cause panics if

``

``

185

`+

/// panics abort instead of unwind.

`

``

186

`+

fn isqrt_consistency_check(n: $T) {

`

``

187

`` +

// <$T>::MIN will be negative, so ignore it in this nonnegative

``

``

188

`+

// section.

`

``

189

`+

if n >= 0 {

`

``

190

`+

assert_eq!(

`

``

191

`+

Some(n.isqrt()),

`

``

192

`+

n.checked_isqrt(),

`

``

193

`` +

"{n}.checked_isqrt() should match Some({n}.isqrt()).",

``

``

194

`+

);

`

``

195

`+

}

`

``

196

+

``

197

`` +

// wrapping_neg so that <$T>::MIN will negate to itself rather

``

``

198

`+

// than panicking.

`

``

199

`+

let negative_n = n.wrapping_neg();

`

``

200

+

``

201

`+

// Zero negated will still be nonnegative, so ignore it in this

`

``

202

`+

// negative section.

`

``

203

`+

if negative_n < 0 {

`

``

204

`+

assert_eq!(

`

``

205

`+

negative_n.checked_isqrt(),

`

``

206

`+

None,

`

``

207

`` +

"({negative_n}).checked_isqrt() should be None, as {negative_n} is negative.",

``

``

208

`+

);

`

``

209

+

``

210

`` +

// catch_unwind only works when panics unwind rather than abort.

``

``

211

`+

#[cfg(panic = "unwind")]

`

``

212

`+

{

`

``

213

`+

std::panic::catch_unwind(core::panic::AssertUnwindSafe(|| (-n).isqrt())).expect_err(

`

``

214

`` +

&format!("({negative_n}).isqrt() should have panicked, as {negative_n} is negative.")

``

``

215

`+

);

`

``

216

`+

}

`

``

217

`+

}

`

``

218

`+

}

`

``

219

`+

};

`

``

220

`+

}

`

``

221

+

``

222

`+

macro_rules! unsigned_check {

`

``

223

`+

($T:ident) => {

`

``

224

`` +

/// This takes an input and, if it's nonzero, checks that isqrt

``

``

225

`+

/// produces the same numeric value for both

`

``

226

`` +

#[doc = concat!("", stringify!($T), " and ")]

``

``

227

`` +

#[doc = concat!("NonZero<", stringify!($T), ">.")]

``

``

228

`+

fn isqrt_consistency_check(n: $T) {

`

``

229

`` +

// Zero cannot be turned into a NonZero value, so ignore it in

``

``

230

`+

// this nonzero section.

`

``

231

`+

if n > 0 {

`

``

232

`+

assert_eq!(

`

``

233

`+

n.isqrt(),

`

``

234

`+

core::num::NonZero::<$T>::new(n)

`

``

235

`+

.expect(

`

``

236

`` +

"Was not able to create a new NonZero value from a nonzero number."

``

``

237

`+

)

`

``

238

`+

.isqrt()

`

``

239

`+

.get(),

`

``

240

`` +

"{n}.isqrt should match NonZero's {n}.isqrt().get().",

``

``

241

`+

);

`

``

242

`+

}

`

``

243

`+

}

`

``

244

`+

};

`

``

245

`+

}

`

``

246

+

``

247

`+

tests!(signed_check: i8 i16 i32 i64 i128);

`

``

248

`+

tests!(unsigned_check: u8 u16 u32 u64 u128);

`