ENH: add and register Arrow extension types for Period and Interval (… · pandas-dev/pandas@2198f51 (original) (raw)

``

1

`+

from distutils.version import LooseVersion

`

``

2

`+

import json

`

``

3

+

``

4

`+

import numpy as np

`

``

5

`+

import pyarrow

`

``

6

+

``

7

`+

from pandas.core.arrays.interval import _VALID_CLOSED

`

``

8

+

``

9

`+

# Feature flag: True when the installed pyarrow is at least 0.15, the first
# release that supports the extension types defined in this module.
# The version string lives in ``pyarrow.__version__``.
_pyarrow_version_ge_015 = LooseVersion(pyarrow.__version__) >= LooseVersion("0.15")

`

``

10

+

``

11

+

``

12

`+

def pyarrow_array_to_numpy_and_mask(arr, dtype):
    """
    Convert a primitive pyarrow.Array to a numpy array and boolean mask based
    on the buffers of the Array.

    Parameters
    ----------
    arr : pyarrow.Array
    dtype : numpy.dtype

    Returns
    -------
    (data, mask)
        Tuple of two numpy arrays with the raw data (with specified dtype) and
        a boolean mask (validity mask, so False means missing)
    """
    dtype = np.dtype(dtype)
    buflist = arr.buffers()

    # Slice the raw bytes *before* handing them to numpy: the data buffer may
    # carry trailing padding (so its size need not be a multiple of the
    # itemsize) and the array may be an offset view into a larger buffer.
    start = arr.offset * dtype.itemsize
    stop = start + len(arr) * dtype.itemsize
    data = np.frombuffer(buflist[1][start:stop], dtype=dtype)

    bitmask = buflist[0]
    if bitmask is not None:
        # The validity bitmap is shared with the parent buffer, so it has to
        # be read starting at the same element offset as the data.
        mask = pyarrow.BooleanArray.from_buffers(
            pyarrow.bool_(), len(arr), [None, bitmask], offset=arr.offset
        )
        mask = np.asarray(mask)
    else:
        # No validity bitmap means no missing values.
        mask = np.ones(len(arr), dtype=bool)
    return data, mask

`

``

39

+

``

40

+

``

41

`+

if _pyarrow_version_ge_015:
    # the pyarrow extension types are only available for pyarrow 0.15+

    class ArrowPeriodType(pyarrow.ExtensionType):
        """
        Arrow extension type named "pandas.period", stored as int64.

        The frequency is kept on the type and serialized as JSON metadata.
        """

        def __init__(self, freq):
            # attributes need to be set first before calling
            # super init (as that calls serialize)
            self._freq = freq
            pyarrow.ExtensionType.__init__(self, pyarrow.int64(), "pandas.period")

        @property
        def freq(self):
            """Frequency this type was constructed with."""
            return self._freq

        def __arrow_ext_serialize__(self):
            """Return the type parameters as UTF-8 encoded JSON bytes."""
            metadata = {"freq": self.freq}
            return json.dumps(metadata).encode()

        @classmethod
        def __arrow_ext_deserialize__(cls, storage_type, serialized):
            """Rebuild the type from the bytes written by the serialize hook."""
            metadata = json.loads(serialized.decode())
            return ArrowPeriodType(metadata["freq"])

        def __eq__(self, other):
            # Equal only to the same extension class with the same freq;
            # defer to the other operand for non-extension types.
            if isinstance(other, pyarrow.BaseExtensionType):
                return type(self) == type(other) and self.freq == other.freq
            else:
                return NotImplemented

        def __hash__(self):
            # Defined alongside __eq__ so instances stay hashable.
            return hash((str(self), self.freq))

    # register the type with a dummy instance
    _period_type = ArrowPeriodType("D")
    pyarrow.register_extension_type(_period_type)

    class ArrowIntervalType(pyarrow.ExtensionType):
        """
        Arrow extension type named "pandas.interval".

        Stored as a struct with "left" and "right" fields of the given
        subtype; subtype and closed side are serialized as JSON metadata.
        """

        def __init__(self, subtype, closed):
            # attributes need to be set first before calling
            # super init (as that calls serialize)
            assert closed in _VALID_CLOSED
            self._closed = closed
            if not isinstance(subtype, pyarrow.DataType):
                # Accept anything whose str() is a pyarrow type alias.
                subtype = pyarrow.type_for_alias(str(subtype))
            self._subtype = subtype

            storage_type = pyarrow.struct([("left", subtype), ("right", subtype)])
            pyarrow.ExtensionType.__init__(self, storage_type, "pandas.interval")

        @property
        def subtype(self):
            """pyarrow.DataType of the interval bounds."""
            return self._subtype

        @property
        def closed(self):
            """Closed side this type was constructed with."""
            return self._closed

        def __arrow_ext_serialize__(self):
            """Return the type parameters as UTF-8 encoded JSON bytes."""
            metadata = {"subtype": str(self.subtype), "closed": self.closed}
            return json.dumps(metadata).encode()

        @classmethod
        def __arrow_ext_deserialize__(cls, storage_type, serialized):
            """Rebuild the type from the bytes written by the serialize hook."""
            metadata = json.loads(serialized.decode())
            subtype = pyarrow.type_for_alias(metadata["subtype"])
            closed = metadata["closed"]
            return ArrowIntervalType(subtype, closed)

        def __eq__(self, other):
            # Equal only to the same extension class with matching subtype
            # and closed side; defer for non-extension types.
            if isinstance(other, pyarrow.BaseExtensionType):
                return (
                    type(self) == type(other)
                    and self.subtype == other.subtype
                    and self.closed == other.closed
                )
            else:
                return NotImplemented

        def __hash__(self):
            # Defined alongside __eq__ so instances stay hashable.
            return hash((str(self), str(self.subtype), self.closed))

    # register the type with a dummy instance
    _interval_type = ArrowIntervalType(pyarrow.int64(), "left")
    pyarrow.register_extension_type(_interval_type)

`