BUG: Perform i8 conversion for datetimelike IntervalTree queries · pandas-dev/pandas@02f19f8 (original) (raw)

1

1

`from future import division

`

2

2

``

``

3

`+

from itertools import permutations

`

3

4

`import pytest

`

4

5

`import numpy as np

`

``

6

`+

import re

`

5

7

`from pandas import (

`

6

8

`Interval, IntervalIndex, Index, isna, notna, interval_range, Timestamp,

`

7

9

`Timedelta, date_range, timedelta_range)

`

`@@ -498,6 +500,48 @@ def test_get_loc_length_one(self, item, closed):

`

498

500

`result = index.get_loc(item)

`

499

501

`assert result == 0

`

500

502

``

``

503

`+

Make consistent with test_interval_new.py (see #16316, #16386)

`

``

504

`+

@pytest.mark.parametrize('breaks', [

`

``

505

`+

date_range('20180101', periods=4),

`

``

506

`+

date_range('20180101', periods=4, tz='US/Eastern'),

`

``

507

`+

timedelta_range('0 days', periods=4)], ids=lambda x: str(x.dtype))

`

``

508

`+

def test_get_loc_datetimelike_nonoverlapping(self, breaks):

`

``

509

`+

GH 20636

`

``

510

`+

nonoverlapping = IntervalIndex method and no i8 conversion

`

``

511

`+

index = IntervalIndex.from_breaks(breaks)

`

``

512

+

``

513

`+

value = index[0].mid

`

``

514

`+

result = index.get_loc(value)

`

``

515

`+

expected = 0

`

``

516

`+

assert result == expected

`

``

517

+

``

518

`+

interval = Interval(index[0].left, index[1].right)

`

``

519

`+

result = index.get_loc(interval)

`

``

520

`+

expected = slice(0, 2)

`

``

521

`+

assert result == expected

`

``

522

+

``

523

`+

Make consistent with test_interval_new.py (see #16316, #16386)

`

``

524

`+

@pytest.mark.parametrize('arrays', [

`

``

525

`+

(date_range('20180101', periods=4), date_range('20180103', periods=4)),

`

``

526

`+

(date_range('20180101', periods=4, tz='US/Eastern'),

`

``

527

`+

date_range('20180103', periods=4, tz='US/Eastern')),

`

``

528

`+

(timedelta_range('0 days', periods=4),

`

``

529

`+

timedelta_range('2 days', periods=4))], ids=lambda x: str(x[0].dtype))

`

``

530

`+

def test_get_loc_datetimelike_overlapping(self, arrays):

`

``

531

`+

GH 20636

`

``

532

`+

overlapping = IntervalTree method with i8 conversion

`

``

533

`+

index = IntervalIndex.from_arrays(*arrays)

`

``

534

+

``

535

`+

value = index[0].mid + Timedelta('12 hours')

`

``

536

`+

result = np.sort(index.get_loc(value))

`

``

537

`+

expected = np.array([0, 1], dtype='int64')

`

``

538

`+

assert tm.assert_numpy_array_equal(result, expected)

`

``

539

+

``

540

`+

interval = Interval(index[0].left, index[1].right)

`

``

541

`+

result = np.sort(index.get_loc(interval))

`

``

542

`+

expected = np.array([0, 1, 2], dtype='int64')

`

``

543

`+

assert tm.assert_numpy_array_equal(result, expected)

`

``

544

+

501

545

`# To be removed, replaced by test_interval_new.py (see #16316, #16386)

`

502

546

`def test_get_indexer(self):

`

503

547

`actual = self.index.get_indexer([-1, 0, 0.5, 1, 1.5, 2, 3])

`

`@@ -555,6 +599,97 @@ def test_get_indexer_length_one(self, item, closed):

`

555

599

`expected = np.array([0] * len(item), dtype='intp')

`

556

600

`tm.assert_numpy_array_equal(result, expected)

`

557

601

``

``

602

`+

Make consistent with test_interval_new.py (see #16316, #16386)

`

``

603

`+

@pytest.mark.parametrize('arrays', [

`

``

604

`+

(date_range('20180101', periods=4), date_range('20180103', periods=4)),

`

``

605

`+

(date_range('20180101', periods=4, tz='US/Eastern'),

`

``

606

`+

date_range('20180103', periods=4, tz='US/Eastern')),

`

``

607

`+

(timedelta_range('0 days', periods=4),

`

``

608

`+

timedelta_range('2 days', periods=4))], ids=lambda x: str(x[0].dtype))

`

``

609

`+

def test_get_reindexer_datetimelike(self, arrays):

`

``

610

`+

GH 20636

`

``

611

`+

index = IntervalIndex.from_arrays(*arrays)

`

``

612

`+

tuples = [(index[0].left, index[0].left + pd.Timedelta('12H')),

`

``

613

`+

(index[-1].right - pd.Timedelta('12H'), index[-1].right)]

`

``

614

`+

target = IntervalIndex.from_tuples(tuples)

`

``

615

+

``

616

`+

result = index._get_reindexer(target)

`

``

617

`+

expected = np.array([0, 3], dtype='int64')

`

``

618

`+

tm.assert_numpy_array_equal(result, expected)

`

``

619

+

``

620

`+

@pytest.mark.parametrize('breaks', [

`

``

621

`+

date_range('20180101', periods=4),

`

``

622

`+

date_range('20180101', periods=4, tz='US/Eastern'),

`

``

623

`+

timedelta_range('0 days', periods=4)], ids=lambda x: str(x.dtype))

`

``

624

`+

def test_maybe_convert_i8(self, breaks):

`

``

625

`+

GH 20636

`

``

626

`+

index = IntervalIndex.from_breaks(breaks)

`

``

627

+

``

628

`+

intervalindex

`

``

629

`+

result = index._maybe_convert_i8(index)

`

``

630

`+

expected = IntervalIndex.from_breaks(breaks.asi8)

`

``

631

`+

tm.assert_index_equal(result, expected)

`

``

632

+

``

633

`+

interval

`

``

634

`+

interval = Interval(breaks[0], breaks[1])

`

``

635

`+

result = index._maybe_convert_i8(interval)

`

``

636

`+

expected = Interval(breaks[0].value, breaks[1].value)

`

``

637

`+

assert result == expected

`

``

638

+

``

639

`+

datetimelike index

`

``

640

`+

result = index._maybe_convert_i8(breaks)

`

``

641

`+

expected = Index(breaks.asi8)

`

``

642

`+

tm.assert_index_equal(result, expected)

`

``

643

+

``

644

`+

datetimelike scalar

`

``

645

`+

result = index._maybe_convert_i8(breaks[0])

`

``

646

`+

expected = breaks[0].value

`

``

647

`+

assert result == expected

`

``

648

+

``

649

`+

list-like of datetimelike scalars

`

``

650

`+

result = index._maybe_convert_i8(list(breaks))

`

``

651

`+

expected = Index(breaks.asi8)

`

``

652

`+

tm.assert_index_equal(result, expected)

`

``

653

+

``

654

`+

@pytest.mark.parametrize('breaks', [

`

``

655

`+

np.arange(5, dtype='int64'),

`

``

656

`+

np.arange(5, dtype='float64')], ids=lambda x: str(x.dtype))

`

``

657

`+

def test_maybe_convert_i8_numeric(self, breaks):

`

``

658

`+

GH 20636

`

``

659

`+

index = IntervalIndex.from_breaks(breaks)

`

``

660

`+

numeric_keys = [

`

``

661

`+

IntervalIndex.from_breaks(breaks),

`

``

662

`+

Interval(breaks[0], breaks[1]),

`

``

663

`+

breaks,

`

``

664

`+

breaks[0],

`

``

665

`+

list(breaks)]

`

``

666

+

``

667

`+

no conversion occurs for numeric

`

``

668

`+

for key in numeric_keys:

`

``

669

`+

result = index._maybe_convert_i8(key)

`

``

670

`+

assert result is key

`

``

671

+

``

672

`+

@pytest.mark.parametrize('breaks1, breaks2', permutations([

`

``

673

`+

date_range('20180101', periods=4),

`

``

674

`+

date_range('20180101', periods=4, tz='US/Eastern'),

`

``

675

`+

timedelta_range('0 days', periods=4)], 2), ids=lambda x: str(x.dtype))

`

``

676

`+

def test_maybe_convert_i8_errors(self, breaks1, breaks2):

`

``

677

`+

GH 20636

`

``

678

`+

index = IntervalIndex.from_breaks(breaks1)

`

``

679

`+

invalid_keys = [

`

``

680

`+

IntervalIndex.from_breaks(breaks2),

`

``

681

`+

Interval(breaks2[0], breaks2[1]),

`

``

682

`+

breaks2,

`

``

683

`+

breaks2[0],

`

``

684

`+

list(breaks2)]

`

``

685

+

``

686

`+

msg = ('Cannot index an IntervalIndex of subtype {dtype1} with '

`

``

687

`+

'values of dtype {dtype2}')

`

``

688

`+

msg = re.escape(msg.format(dtype1=breaks1.dtype, dtype2=breaks2.dtype))

`

``

689

`+

for key in invalid_keys:

`

``

690

`+

with tm.assert_raises_regex(ValueError, msg):

`

``

691

`+

index._maybe_convert_i8(key)

`

``

692

+

558

693

`# To be removed, replaced by test_interval_new.py (see #16316, #16386)

`

559

694

`def test_contains(self):

`

560

695

`# Only endpoints are valid.

`