bpo-36887: add math.isqrt (GH-13244) · python/cpython@73934b9 (original) (raw)

`@@ -1476,6 +1476,266 @@ count_set_bits(unsigned long n)

`

1476

1476

`return count;

`

1477

1477

`}

`

1478

1478

``

``

1479

`+

/* Integer square root

`

``

1480

+

``

1481

`` +

Given a nonnegative integer n, we want to compute the largest integer

``

``

1482

`` +

a for which a * a <= n, or equivalently the integer part of the exact

``

``

1483

`` +

square root of n.

``

``

1484

+

``

1485

`+

We use an adaptive-precision pure-integer version of Newton's iteration. Given

`

``

1486

`` +

a positive integer n, the algorithm produces at each iteration an integer

``

``

1487

`` +

approximation a to the square root of n >> s for some even integer s,

``

``

1488

`` +

with s decreasing as the iterations progress. On the final iteration, s is

``

``

1489

`` +

zero and we have an approximation to the square root of n itself.

``

``

1490

+

``

1491

`` +

At every step, the approximation a is strictly within 1.0 of the true square

``

``

1492

`+

root, so we have

`

``

1493

+

``

1494

`+

(a - 1)**2 < (n >> s) < (a + 1)**2

`

``

1495

+

``

1496

`+

After the final iteration, a check-and-correct step is needed to determine

`

``

1497

`` +

whether a or a - 1 gives the desired integer square root of n.

``

``

1498

+

``

1499

`+

The algorithm is remarkable in its simplicity. There's no need for a

`

``

1500

`+

per-iteration check-and-correct step, and termination is straightforward: the

`

``

1501

`` +

number of iterations is known in advance (it's exactly floor(log2(log2(n)))

``

``

1502

`` +

for n > 1). The only tricky part of the correctness proof is in establishing

``

``

1503

`` +

that the bound (a - 1)**2 < (n >> s) < (a + 1)**2 is maintained from one

``

``

1504

`+

iteration to the next. A sketch of the proof of this is given below.

`

``

1505

+

``

1506

`+

In addition to the proof sketch, a formal, computer-verified proof

`

``

1507

`+

of correctness (using Lean) of an equivalent recursive algorithm can be found

`

``

1508

`+

here:

`

``

1509

+

``

1510

`+

https://github.com/mdickinson/snippets/blob/master/proofs/isqrt/src/isqrt.lean

`

``

1511

+

``

1512

+

``

1513

`+

Here's Python code equivalent to the C implementation below:

`

``

1514

+

``

1515

`+

def isqrt(n):

`

``

1516

`+

"""

`

``

1517

`+

Return the integer part of the square root of the input.

`

``

1518

`+

"""

`

``

1519

`+

n = operator.index(n)

`

``

1520

+

``

1521

`+

if n < 0:

`

``

1522

`+

raise ValueError("isqrt() argument must be nonnegative")

`

``

1523

`+

if n == 0:

`

``

1524

`+

return 0

`

``

1525

+

``

1526

`+

c = (n.bit_length() - 1) // 2

`

``

1527

`+

a = 1

`

``

1528

`+

d = 0

`

``

1529

`+

for s in reversed(range(c.bit_length())):

`

``

1530

`+

e = d

`

``

1531

`+

d = c >> s

`

``

1532

`+

a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a

`

``

1533

`+

assert (a-1)**2 < n >> 2*(c - d) < (a+1)**2

`

``

1534

+

``

1535

`+

return a - (a*a > n)

`

``

1536

+

``

1537

+

``

1538

`+

Sketch of proof of correctness

`

``

1539

`+


`

``

1540

+

``

1541

`+

The delicate part of the correctness proof is showing that the loop invariant

`

``

1542

`+

is preserved from one iteration to the next. That is, just before the line

`

``

1543

+

``

1544

`+

a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a

`

``

1545

+

``

1546

`+

is executed in the above code, we know that

`

``

1547

+

``

1548

`+

(1) (a - 1)**2 < (n >> 2*(c - e)) < (a + 1)**2.

`

``

1549

+

``

1550

`` +

(since e is always the value of d from the previous iteration). We must

``

``

1551

`+

prove that after that line is executed, we have

`

``

1552

+

``

1553

`+

(a - 1)**2 < (n >> 2*(c - d)) < (a + 1)**2

`

``

1554

+

``

1555

`` +

To facilitate the proof, we make some changes of notation. Write m for

``

``

1556

`` +

n >> 2*(c-d), and write b for the new value of a, so

``

``

1557

+

``

1558

`+

b = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a

`

``

1559

+

``

1560

`+

or equivalently:

`

``

1561

+

``

1562

`+

(2) b = (a << d - e - 1) + (m >> d - e + 1) // a

`

``

1563

+

``

1564

`+

Then we can rewrite (1) as:

`

``

1565

+

``

1566

`+

(3) (a - 1)**2 < (m >> 2*(d - e)) < (a + 1)**2

`

``

1567

+

``

1568

`+

and we must show that (b - 1)**2 < m < (b + 1)**2.

`

``

1569

+

``

1570

`` +

From this point on, we switch to mathematical notation, so / means exact

``

``

1571

`` +

division rather than integer division and ^ is used for exponentiation. We

``

``

1572

`` +

use the √ symbol for the exact square root. In (3), we can remove the

``

``

1573

`+

implicit floor operation to give:

`

``

1574

+

``

1575

`+

(4) (a - 1)^2 < m / 4^(d - e) < (a + 1)^2

`

``

1576

+

``

1577

`` +

Taking square roots throughout (4), scaling by 2^(d-e), and rearranging gives

``

``

1578

+

``

1579

`+

(5) 0 <= | 2^(d-e)a - √m | < 2^(d-e)

`

``

1580

+

``

1581

`` +

Squaring and dividing through by 2^(d-e+1) a gives

``

``

1582

+

``

1583

`+

(6) 0 <= 2^(d-e-1) a + m / (2^(d-e+1) a) - √m < 2^(d-e-1) / a

`

``

1584

+

``

1585

`` +

We'll show below that 2^(d-e-1) <= a. Given that, we can replace the

``

``

1586

`` +

right-hand side of (6) with 1, and now replacing the central

``

``

1587

`` +

term m / (2^(d-e+1) a) with its floor in (6) gives

``

``

1588

+

``

1589

`+

(7) -1 < 2^(d-e-1) a + m // (2^(d-e+1) a) - √m < 1

`

``

1590

+

``

1591

`+

Or equivalently, from (2):

`

``

1592

+

``

1593

`+

(7) -1 < b - √m < 1

`

``

1594

+

``

1595

`` +

and rearranging gives that (b-1)^2 < m < (b+1)^2, which is what we needed

``

``

1596

`+

to prove.

`

``

1597

+

``

1598

`` +

We're not quite done: we still have to prove the inequality 2^(d - e - 1) <=

``

``

1599

`` +

a that was used to get line (7) above. From the definition of c, we have

``

``

1600

`` +

4^c <= n, which implies

``

``

1601

+

``

1602

`+

(8) 4^d <= m

`

``

1603

+

``

1604

`` +

also, since e == d >> 1, d is at most 2e + 1, from which it follows

``

``

1605

`` +

that 2d - 2e - 1 <= d and hence that

``

``

1606

+

``

1607

`+

(9) 4^(2d - 2e - 1) <= m

`

``

1608

+

``

1609

`` +

Dividing both sides by 4^(d - e) gives

``

``

1610

+

``

1611

`+

(10) 4^(d - e - 1) <= m / 4^(d - e)

`

``

1612

+

``

1613

`` +

But we know from (4) that m / 4^(d-e) < (a + 1)^2, hence

``

``

1614

+

``

1615

`+

(11) 4^(d - e - 1) < (a + 1)^2

`

``

1616

+

``

1617

`` +

Now taking square roots of both sides and observing that both 2^(d-e-1) and

``

``

1618

`` +

a are integers gives 2^(d - e - 1) <= a, which is what we needed. This

``

``

1619

`+

completes the proof sketch.

`

``

1620

+

``

1621

`+

*/

`

``

1622

+

``

1623

`+

/*[clinic input]

`

``

1624

`+

math.isqrt

`

``

1625

+

``

1626

`+

n: object

`

``

1627

`+

/

`

``

1628

+

``

1629

`+

Return the integer part of the square root of the input.

`

``

1630

`+

[clinic start generated code]*/

`

``

1631

+

``

1632

`+

static PyObject *

`

``

1633

`+

math_isqrt(PyObject *module, PyObject *n)

`

``

1634

`+

/*[clinic end generated code: output=35a6f7f980beab26 input=5b6e7ae4fa6c43d6]*/

`

``

1635

`+

{

`

``

1636

`+

int a_too_large, s;

`

``

1637

`+

size_t c, d;

`

``

1638

`+

PyObject *a = NULL, *b;

`

``

1639

+

``

1640

`+

n = PyNumber_Index(n);

`

``

1641

`+

if (n == NULL) {

`

``

1642

`+

return NULL;

`

``

1643

`+

}

`

``

1644

+

``

1645

`+

if (_PyLong_Sign(n) < 0) {

`

``

1646

`+

PyErr_SetString(

`

``

1647

`+

PyExc_ValueError,

`

``

1648

`+

"isqrt() argument must be nonnegative");

`

``

1649

`+

goto error;

`

``

1650

`+

}

`

``

1651

`+

if (_PyLong_Sign(n) == 0) {

`

``

1652

`+

Py_DECREF(n);

`

``

1653

`+

return PyLong_FromLong(0);

`

``

1654

`+

}

`

``

1655

+

``

1656

`+

c = _PyLong_NumBits(n);

`

``

1657

`+

if (c == (size_t)(-1)) {

`

``

1658

`+

goto error;

`

``

1659

`+

}

`

``

1660

`+

c = (c - 1U) / 2U;

`

``

1661

+

``

1662

`+

/* s = c.bit_length() */

`

``

1663

`+

s = 0;

`

``

1664

`+

while ((c >> s) > 0) {

`

``

1665

`+

++s;

`

``

1666

`+

}

`

``

1667

+

``

1668

`+

a = PyLong_FromLong(1);

`

``

1669

`+

if (a == NULL) {

`

``

1670

`+

goto error;

`

``

1671

`+

}

`

``

1672

`+

d = 0;

`

``

1673

`+

while (--s >= 0) {

`

``

1674

`+

PyObject *q, *shift;

`

``

1675

`+

size_t e = d;

`

``

1676

+

``

1677

`+

d = c >> s;

`

``

1678

+

``

1679

`+

/* q = (n >> 2*c - e - d + 1) // a */

`

``

1680

`+

shift = PyLong_FromSize_t(2U*c - d - e + 1U);

`

``

1681

`+

if (shift == NULL) {

`

``

1682

`+

goto error;

`

``

1683

`+

}

`

``

1684

`+

q = PyNumber_Rshift(n, shift);

`

``

1685

`+

Py_DECREF(shift);

`

``

1686

`+

if (q == NULL) {

`

``

1687

`+

goto error;

`

``

1688

`+

}

`

``

1689

`+

Py_SETREF(q, PyNumber_FloorDivide(q, a));

`

``

1690

`+

if (q == NULL) {

`

``

1691

`+

goto error;

`

``

1692

`+

}

`

``

1693

+

``

1694

`+

/* a = (a << d - 1 - e) + q */

`

``

1695

`+

shift = PyLong_FromSize_t(d - 1U - e);

`

``

1696

`+

if (shift == NULL) {

`

``

1697

`+

Py_DECREF(q);

`

``

1698

`+

goto error;

`

``

1699

`+

}

`

``

1700

`+

Py_SETREF(a, PyNumber_Lshift(a, shift));

`

``

1701

`+

Py_DECREF(shift);

`

``

1702

`+

if (a == NULL) {

`

``

1703

`+

Py_DECREF(q);

`

``

1704

`+

goto error;

`

``

1705

`+

}

`

``

1706

`+

Py_SETREF(a, PyNumber_Add(a, q));

`

``

1707

`+

Py_DECREF(q);

`

``

1708

`+

if (a == NULL) {

`

``

1709

`+

goto error;

`

``

1710

`+

}

`

``

1711

`+

}

`

``

1712

+

``

1713

`+

/* The correct result is either a or a - 1. Figure out which, and

`

``

1714

`+

decrement a if necessary. */

`

``

1715

+

``

1716

`+

/* a_too_large = n < a * a */

`

``

1717

`+

b = PyNumber_Multiply(a, a);

`

``

1718

`+

if (b == NULL) {

`

``

1719

`+

goto error;

`

``

1720

`+

}

`

``

1721

`+

a_too_large = PyObject_RichCompareBool(n, b, Py_LT);

`

``

1722

`+

Py_DECREF(b);

`

``

1723

`+

if (a_too_large == -1) {

`

``

1724

`+

goto error;

`

``

1725

`+

}

`

``

1726

+

``

1727

`+

if (a_too_large) {

`

``

1728

`+

Py_SETREF(a, PyNumber_Subtract(a, _PyLong_One));

`

``

1729

`+

}

`

``

1730

`+

Py_DECREF(n);

`

``

1731

`+

return a;

`

``

1732

+

``

1733

`+

error:

`

``

1734

`+

Py_XDECREF(a);

`

``

1735

`+

Py_DECREF(n);

`

``

1736

`+

return NULL;

`

``

1737

`+

}

`

``

1738

+

1479

1739

`/* Divide-and-conquer factorial algorithm

`

1480

1740

` *

`

1481

1741

` * Based on the formula and pseudo-code provided at:

`

`@@ -2737,6 +2997,7 @@ static PyMethodDef math_methods[] = {

`

2737

2997

`MATH_ISFINITE_METHODDEF

`

2738

2998

`MATH_ISINF_METHODDEF

`

2739

2999

`MATH_ISNAN_METHODDEF

`

``

3000

`+

MATH_ISQRT_METHODDEF

`

2740

3001

`MATH_LDEXP_METHODDEF

`

2741

3002

` {"lgamma", math_lgamma, METH_O, math_lgamma_doc},

`

2742

3003

`MATH_LOG_METHODDEF

`