ENH: add and register Arrow extension types for Period and Interval (… · pandas-dev/pandas@2198f51 (original) (raw)
``
1
`+
from distutils.version import LooseVersion
`
``
2
`+
import json
`
``
3
+
``
4
`+
import numpy as np
`
``
5
`+
import pyarrow
`
``
6
+
``
7
`+
from pandas.core.arrays.interval import _VALID_CLOSED
`
``
8
+
``
9
`+
_pyarrow_version_ge_015 = LooseVersion(pyarrow.version) >= LooseVersion("0.15")
`
``
10
+
``
11
+
``
12
`+
def pyarrow_array_to_numpy_and_mask(arr, dtype):
`
``
13
`+
"""
`
``
14
`+
Convert a primitive pyarrow.Array to a numpy array and boolean mask based
`
``
15
`+
on the buffers of the Array.
`
``
16
+
``
17
`+
Parameters
`
``
18
`+
`
``
19
`+
arr : pyarrow.Array
`
``
20
`+
dtype : numpy.dtype
`
``
21
+
``
22
`+
Returns
`
``
23
`+
`
``
24
`+
(data, mask)
`
``
25
`+
Tuple of two numpy arrays with the raw data (with specified dtype) and
`
``
26
`+
a boolean mask (validity mask, so False means missing)
`
``
27
`+
"""
`
``
28
`+
buflist = arr.buffers()
`
``
29
`+
data = np.frombuffer(buflist[1], dtype=dtype)[arr.offset : arr.offset + len(arr)]
`
``
30
`+
bitmask = buflist[0]
`
``
31
`+
if bitmask is not None:
`
``
32
`+
mask = pyarrow.BooleanArray.from_buffers(
`
``
33
`+
pyarrow.bool_(), len(arr), [None, bitmask]
`
``
34
`+
)
`
``
35
`+
mask = np.asarray(mask)
`
``
36
`+
else:
`
``
37
`+
mask = np.ones(len(arr), dtype=bool)
`
``
38
`+
return data, mask
`
``
39
+
``
40
+
``
41
`+
if _pyarrow_version_ge_015:
`
``
42
`+
the pyarrow extension types are only available for pyarrow 0.15+
`
``
43
+
``
44
`+
class ArrowPeriodType(pyarrow.ExtensionType):
`
``
45
`+
def init(self, freq):
`
``
46
`+
attributes need to be set first before calling
`
``
47
`+
super init (as that calls serialize)
`
``
48
`+
self._freq = freq
`
``
49
`+
pyarrow.ExtensionType.init(self, pyarrow.int64(), "pandas.period")
`
``
50
+
``
51
`+
@property
`
``
52
`+
def freq(self):
`
``
53
`+
return self._freq
`
``
54
+
``
55
`+
def arrow_ext_serialize(self):
`
``
56
`+
metadata = {"freq": self.freq}
`
``
57
`+
return json.dumps(metadata).encode()
`
``
58
+
``
59
`+
@classmethod
`
``
60
`+
def arrow_ext_deserialize(cls, storage_type, serialized):
`
``
61
`+
metadata = json.loads(serialized.decode())
`
``
62
`+
return ArrowPeriodType(metadata["freq"])
`
``
63
+
``
64
`+
def eq(self, other):
`
``
65
`+
if isinstance(other, pyarrow.BaseExtensionType):
`
``
66
`+
return type(self) == type(other) and self.freq == other.freq
`
``
67
`+
else:
`
``
68
`+
return NotImplemented
`
``
69
+
``
70
`+
def hash(self):
`
``
71
`+
return hash((str(self), self.freq))
`
``
72
+
``
73
`+
register the type with a dummy instance
`
``
74
`+
_period_type = ArrowPeriodType("D")
`
``
75
`+
pyarrow.register_extension_type(_period_type)
`
``
76
+
``
77
`+
class ArrowIntervalType(pyarrow.ExtensionType):
`
``
78
`+
def init(self, subtype, closed):
`
``
79
`+
attributes need to be set first before calling
`
``
80
`+
super init (as that calls serialize)
`
``
81
`+
assert closed in _VALID_CLOSED
`
``
82
`+
self._closed = closed
`
``
83
`+
if not isinstance(subtype, pyarrow.DataType):
`
``
84
`+
subtype = pyarrow.type_for_alias(str(subtype))
`
``
85
`+
self._subtype = subtype
`
``
86
+
``
87
`+
storage_type = pyarrow.struct([("left", subtype), ("right", subtype)])
`
``
88
`+
pyarrow.ExtensionType.init(self, storage_type, "pandas.interval")
`
``
89
+
``
90
`+
@property
`
``
91
`+
def subtype(self):
`
``
92
`+
return self._subtype
`
``
93
+
``
94
`+
@property
`
``
95
`+
def closed(self):
`
``
96
`+
return self._closed
`
``
97
+
``
98
`+
def arrow_ext_serialize(self):
`
``
99
`+
metadata = {"subtype": str(self.subtype), "closed": self.closed}
`
``
100
`+
return json.dumps(metadata).encode()
`
``
101
+
``
102
`+
@classmethod
`
``
103
`+
def arrow_ext_deserialize(cls, storage_type, serialized):
`
``
104
`+
metadata = json.loads(serialized.decode())
`
``
105
`+
subtype = pyarrow.type_for_alias(metadata["subtype"])
`
``
106
`+
closed = metadata["closed"]
`
``
107
`+
return ArrowIntervalType(subtype, closed)
`
``
108
+
``
109
`+
def eq(self, other):
`
``
110
`+
if isinstance(other, pyarrow.BaseExtensionType):
`
``
111
`+
return (
`
``
112
`+
type(self) == type(other)
`
``
113
`+
and self.subtype == other.subtype
`
``
114
`+
and self.closed == other.closed
`
``
115
`+
)
`
``
116
`+
else:
`
``
117
`+
return NotImplemented
`
``
118
+
``
119
`+
def hash(self):
`
``
120
`+
return hash((str(self), str(self.subtype), self.closed))
`
``
121
+
``
122
`+
register the type with a dummy instance
`
``
123
`+
_interval_type = ArrowIntervalType(pyarrow.int64(), "left")
`
``
124
`+
pyarrow.register_extension_type(_interval_type)
`