bpo-36887: add math.isqrt (GH-13244) · python/cpython@73934b9 (original) (raw)
`@@ -1476,6 +1476,266 @@ count_set_bits(unsigned long n)
`
1476
1476
`return count;
`
1477
1477
`}
`
1478
1478
``
``
1479
`+
/* Integer square root
`
``
1480
+
``
1481
`` +
Given a nonnegative integer n
, we want to compute the largest integer
``
``
1482
`` +
a
for which a * a <= n
, or equivalently the integer part of the exact
``
``
1483
`` +
square root of n
.
``
``
1484
+
``
1485
`+
We use an adaptive-precision pure-integer version of Newton's iteration. Given
`
``
1486
`` +
a positive integer n
, the algorithm produces at each iteration an integer
``
``
1487
`` +
approximation a
to the square root of n >> s
for some even integer s
,
``
``
1488
`` +
with s
decreasing as the iterations progress. On the final iteration, s
is
``
``
1489
`` +
zero and we have an approximation to the square root of n
itself.
``
``
1490
+
``
1491
`` +
At every step, the approximation a
is strictly within 1.0 of the true square
``
``
1492
`+
root, so we have
`
``
1493
+
``
1494
`+
(a - 1)**2 < (n >> s) < (a + 1)**2
`
``
1495
+
``
1496
`+
After the final iteration, a check-and-correct step is needed to determine
`
``
1497
`` +
whether a
or a - 1
gives the desired integer square root of n
.
``
``
1498
+
``
1499
`+
The algorithm is remarkable in its simplicity. There's no need for a
`
``
1500
`+
per-iteration check-and-correct step, and termination is straightforward: the
`
``
1501
`` +
number of iterations is known in advance (it's exactly floor(log2(log2(n)))
``
``
1502
`` +
for n > 1
). The only tricky part of the correctness proof is in establishing
``
``
1503
`` +
that the bound (a - 1)**2 < (n >> s) < (a + 1)**2
is maintained from one
``
``
1504
`+
iteration to the next. A sketch of the proof of this is given below.
`
``
1505
+
``
1506
`+
In addition to the proof sketch, a formal, computer-verified proof
`
``
1507
`+
of correctness (using Lean) of an equivalent recursive algorithm can be found
`
``
1508
`+
here:
`
``
1509
+
``
1510
`+
https://github.com/mdickinson/snippets/blob/master/proofs/isqrt/src/isqrt.lean
`
``
1511
+
``
1512
+
``
1513
`+
Here's Python code equivalent to the C implementation below:
`
``
1514
+
``
1515
`+
def isqrt(n):
`
``
1516
`+
"""
`
``
1517
`+
Return the integer part of the square root of the input.
`
``
1518
`+
"""
`
``
1519
`+
n = operator.index(n)
`
``
1520
+
``
1521
`+
if n < 0:
`
``
1522
`+
raise ValueError("isqrt() argument must be nonnegative")
`
``
1523
`+
if n == 0:
`
``
1524
`+
return 0
`
``
1525
+
``
1526
`+
c = (n.bit_length() - 1) // 2
`
``
1527
`+
a = 1
`
``
1528
`+
d = 0
`
``
1529
`+
for s in reversed(range(c.bit_length())):
`
``
1530
`+
e = d
`
``
1531
`+
d = c >> s
`
``
1532
`+
a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
`
``
1533
`+
assert (a-1)*2 < n >> 2(c - d) < (a+1)**2
`
``
1534
+
``
1535
`+
return a - (a*a > n)
`
``
1536
+
``
1537
+
``
1538
`+
Sketch of proof of correctness
`
``
1539
`+
`
``
1540
+
``
1541
`+
The delicate part of the correctness proof is showing that the loop invariant
`
``
1542
`+
is preserved from one iteration to the next. That is, just before the line
`
``
1543
+
``
1544
`+
a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
`
``
1545
+
``
1546
`+
is executed in the above code, we know that
`
``
1547
+
``
1548
`+
(1) (a - 1)*2 < (n >> 2(c - e)) < (a + 1)**2.
`
``
1549
+
``
1550
`` +
(since e
is always the value of d
from the previous iteration). We must
``
``
1551
`+
prove that after that line is executed, we have
`
``
1552
+
``
1553
`+
(a - 1)*2 < (n >> 2(c - d)) < (a + 1)**2
`
``
1554
+
``
1555
`` +
To faciliate the proof, we make some changes of notation. Write m
for
``
``
1556
`` +
n >> 2*(c-d)
, and write b
for the new value of a
, so
``
``
1557
+
``
1558
`+
b = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
`
``
1559
+
``
1560
`+
or equivalently:
`
``
1561
+
``
1562
`+
(2) b = (a << d - e - 1) + (m >> d - e + 1) // a
`
``
1563
+
``
1564
`+
Then we can rewrite (1) as:
`
``
1565
+
``
1566
`+
(3) (a - 1)*2 < (m >> 2(d - e)) < (a + 1)**2
`
``
1567
+
``
1568
`+
and we must show that (b - 1)**2 < m < (b + 1)**2.
`
``
1569
+
``
1570
`` +
From this point on, we switch to mathematical notation, so /
means exact
``
``
1571
`` +
division rather than integer division and ^
is used for exponentiation. We
``
``
1572
`` +
use the √
symbol for the exact square root. In (3), we can remove the
``
``
1573
`+
implicit floor operation to give:
`
``
1574
+
``
1575
`+
(4) (a - 1)^2 < m / 4^(d - e) < (a + 1)^2
`
``
1576
+
``
1577
`` +
Taking square roots throughout (4), scaling by 2^(d-e)
, and rearranging gives
``
``
1578
+
``
1579
`+
(5) 0 <= | 2^(d-e)a - √m | < 2^(d-e)
`
``
1580
+
``
1581
`` +
Squaring and dividing through by 2^(d-e+1) a
gives
``
``
1582
+
``
1583
`+
(6) 0 <= 2^(d-e-1) a + m / (2^(d-e+1) a) - √m < 2^(d-e-1) / a
`
``
1584
+
``
1585
`` +
We'll show below that 2^(d-e-1) <= a
. Given that, we can replace the
``
``
1586
`` +
right-hand side of (6) with 1
, and now replacing the central
``
``
1587
`` +
term m / (2^(d-e+1) a)
with its floor in (6) gives
``
``
1588
+
``
1589
`+
(7) -1 < 2^(d-e-1) a + m // 2^(d-e+1) a - √m < 1
`
``
1590
+
``
1591
`+
Or equivalently, from (2):
`
``
1592
+
``
1593
`+
(7) -1 < b - √m < 1
`
``
1594
+
``
1595
`` +
and rearranging gives that (b-1)^2 < m < (b+1)^2
, which is what we needed
``
``
1596
`+
to prove.
`
``
1597
+
``
1598
`` +
We're not quite done: we still have to prove the inequality `2^(d - e - 1) <=
``
``
1599
`` +
athat was used to get line (7) above. From the definition of
c`, we have
``
``
1600
`` +
4^c <= n
, which implies
``
``
1601
+
``
1602
`+
(8) 4^d <= m
`
``
1603
+
``
1604
`` +
also, since e == d >> 1
, d
is at most 2e + 1
, from which it follows
``
``
1605
`` +
that 2d - 2e - 1 <= d
and hence that
``
``
1606
+
``
1607
`+
(9) 4^(2d - 2e - 1) <= m
`
``
1608
+
``
1609
`` +
Dividing both sides by 4^(d - e)
gives
``
``
1610
+
``
1611
`+
(10) 4^(d - e - 1) <= m / 4^(d - e)
`
``
1612
+
``
1613
`` +
But we know from (4) that m / 4^(d-e) < (a + 1)^2
, hence
``
``
1614
+
``
1615
`+
(11) 4^(d - e - 1) < (a + 1)^2
`
``
1616
+
``
1617
`` +
Now taking square roots of both sides and observing that both 2^(d-e-1)
and
``
``
1618
`` +
a
are integers gives 2^(d - e - 1) <= a
, which is what we needed. This
``
``
1619
`+
completes the proof sketch.
`
``
1620
+
``
1621
`+
*/
`
``
1622
+
``
1623
`+
/*[clinic input]
`
``
1624
`+
math.isqrt
`
``
1625
+
``
1626
`+
n: object
`
``
1627
`+
/
`
``
1628
+
``
1629
`+
Return the integer part of the square root of the input.
`
``
1630
`+
[clinic start generated code]*/
`
``
1631
+
``
1632
`+
static PyObject *
`
``
1633
`+
math_isqrt(PyObject *module, PyObject *n)
`
``
1634
`+
/[clinic end generated code: output=35a6f7f980beab26 input=5b6e7ae4fa6c43d6]/
`
``
1635
`+
{
`
``
1636
`+
int a_too_large, s;
`
``
1637
`+
size_t c, d;
`
``
1638
`+
PyObject *a = NULL, *b;
`
``
1639
+
``
1640
`+
n = PyNumber_Index(n);
`
``
1641
`+
if (n == NULL) {
`
``
1642
`+
return NULL;
`
``
1643
`+
}
`
``
1644
+
``
1645
`+
if (_PyLong_Sign(n) < 0) {
`
``
1646
`+
PyErr_SetString(
`
``
1647
`+
PyExc_ValueError,
`
``
1648
`+
"isqrt() argument must be nonnegative");
`
``
1649
`+
goto error;
`
``
1650
`+
}
`
``
1651
`+
if (_PyLong_Sign(n) == 0) {
`
``
1652
`+
Py_DECREF(n);
`
``
1653
`+
return PyLong_FromLong(0);
`
``
1654
`+
}
`
``
1655
+
``
1656
`+
c = _PyLong_NumBits(n);
`
``
1657
`+
if (c == (size_t)(-1)) {
`
``
1658
`+
goto error;
`
``
1659
`+
}
`
``
1660
`+
c = (c - 1U) / 2U;
`
``
1661
+
``
1662
`+
/* s = c.bit_length() */
`
``
1663
`+
s = 0;
`
``
1664
`+
while ((c >> s) > 0) {
`
``
1665
`+
++s;
`
``
1666
`+
}
`
``
1667
+
``
1668
`+
a = PyLong_FromLong(1);
`
``
1669
`+
if (a == NULL) {
`
``
1670
`+
goto error;
`
``
1671
`+
}
`
``
1672
`+
d = 0;
`
``
1673
`+
while (--s >= 0) {
`
``
1674
`+
PyObject *q, *shift;
`
``
1675
`+
size_t e = d;
`
``
1676
+
``
1677
`+
d = c >> s;
`
``
1678
+
``
1679
`+
/* q = (n >> 2*c - e - d + 1) // a */
`
``
1680
`+
shift = PyLong_FromSize_t(2U*c - d - e + 1U);
`
``
1681
`+
if (shift == NULL) {
`
``
1682
`+
goto error;
`
``
1683
`+
}
`
``
1684
`+
q = PyNumber_Rshift(n, shift);
`
``
1685
`+
Py_DECREF(shift);
`
``
1686
`+
if (q == NULL) {
`
``
1687
`+
goto error;
`
``
1688
`+
}
`
``
1689
`+
Py_SETREF(q, PyNumber_FloorDivide(q, a));
`
``
1690
`+
if (q == NULL) {
`
``
1691
`+
goto error;
`
``
1692
`+
}
`
``
1693
+
``
1694
`+
/* a = (a << d - 1 - e) + q */
`
``
1695
`+
shift = PyLong_FromSize_t(d - 1U - e);
`
``
1696
`+
if (shift == NULL) {
`
``
1697
`+
Py_DECREF(q);
`
``
1698
`+
goto error;
`
``
1699
`+
}
`
``
1700
`+
Py_SETREF(a, PyNumber_Lshift(a, shift));
`
``
1701
`+
Py_DECREF(shift);
`
``
1702
`+
if (a == NULL) {
`
``
1703
`+
Py_DECREF(q);
`
``
1704
`+
goto error;
`
``
1705
`+
}
`
``
1706
`+
Py_SETREF(a, PyNumber_Add(a, q));
`
``
1707
`+
Py_DECREF(q);
`
``
1708
`+
if (a == NULL) {
`
``
1709
`+
goto error;
`
``
1710
`+
}
`
``
1711
`+
}
`
``
1712
+
``
1713
`+
/* The correct result is either a or a - 1. Figure out which, and
`
``
1714
`+
decrement a if necessary. */
`
``
1715
+
``
1716
`+
/* a_too_large = n < a * a */
`
``
1717
`+
b = PyNumber_Multiply(a, a);
`
``
1718
`+
if (b == NULL) {
`
``
1719
`+
goto error;
`
``
1720
`+
}
`
``
1721
`+
a_too_large = PyObject_RichCompareBool(n, b, Py_LT);
`
``
1722
`+
Py_DECREF(b);
`
``
1723
`+
if (a_too_large == -1) {
`
``
1724
`+
goto error;
`
``
1725
`+
}
`
``
1726
+
``
1727
`+
if (a_too_large) {
`
``
1728
`+
Py_SETREF(a, PyNumber_Subtract(a, _PyLong_One));
`
``
1729
`+
}
`
``
1730
`+
Py_DECREF(n);
`
``
1731
`+
return a;
`
``
1732
+
``
1733
`+
error:
`
``
1734
`+
Py_XDECREF(a);
`
``
1735
`+
Py_DECREF(n);
`
``
1736
`+
return NULL;
`
``
1737
`+
}
`
``
1738
+
1479
1739
`/* Divide-and-conquer factorial algorithm
`
1480
1740
` *
`
1481
1741
` * Based on the formula and pseudo-code provided at:
`
`@@ -2737,6 +2997,7 @@ static PyMethodDef math_methods[] = {
`
2737
2997
`MATH_ISFINITE_METHODDEF
`
2738
2998
`MATH_ISINF_METHODDEF
`
2739
2999
`MATH_ISNAN_METHODDEF
`
``
3000
`+
MATH_ISQRT_METHODDEF
`
2740
3001
`MATH_LDEXP_METHODDEF
`
2741
3002
` {"lgamma", math_lgamma, METH_O, math_lgamma_doc},
`
2742
3003
`MATH_LOG_METHODDEF
`