[SPARK-19019] [PYTHON] Fix hijacked collections.namedtuple and port… · apache/spark@20e6280 (original) (raw)
`@@ -43,6 +43,7 @@
`
43
43
`from __future__ import print_function
`
44
44
``
45
45
`import operator
`
``
46
`+
import opcode
`
46
47
`import os
`
47
48
`import io
`
48
49
`import pickle
`
`@@ -53,6 +54,8 @@
`
53
54
`import itertools
`
54
55
`import dis
`
55
56
`import traceback
`
``
57
`+
import weakref
`
``
58
+
56
59
``
57
60
`if sys.version < '3':
`
58
61
`from pickle import Pickler
`
`@@ -68,10 +71,10 @@
`
68
71
`PY3 = True
`
69
72
``
70
73
`#relevant opcodes
`
71
``
`-
STORE_GLOBAL = dis.opname.index('STORE_GLOBAL')
`
72
``
`-
DELETE_GLOBAL = dis.opname.index('DELETE_GLOBAL')
`
73
``
`-
LOAD_GLOBAL = dis.opname.index('LOAD_GLOBAL')
`
74
``
`-
GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL]
`
``
74
`+
STORE_GLOBAL = opcode.opmap['STORE_GLOBAL']
`
``
75
`+
DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL']
`
``
76
`+
LOAD_GLOBAL = opcode.opmap['LOAD_GLOBAL']
`
``
77
`+
GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL)
`
75
78
`HAVE_ARGUMENT = dis.HAVE_ARGUMENT
`
76
79
`EXTENDED_ARG = dis.EXTENDED_ARG
`
77
80
``
`@@ -90,6 +93,43 @@ def _builtin_type(name):
`
90
93
`return getattr(types, name)
`
91
94
``
92
95
``
``
96
`+
if sys.version_info < (3, 4):
`
``
97
`+
def _walk_global_ops(code):
`
``
98
`+
"""
`
``
99
`+
Yield (opcode, argument number) tuples for all
`
``
100
`+
global-referencing instructions in code.
`
``
101
`+
"""
`
``
102
`+
code = getattr(code, 'co_code', b'')
`
``
103
`+
if not PY3:
`
``
104
`+
code = map(ord, code)
`
``
105
+
``
106
`+
n = len(code)
`
``
107
`+
i = 0
`
``
108
`+
extended_arg = 0
`
``
109
`+
while i < n:
`
``
110
`+
op = code[i]
`
``
111
`+
i += 1
`
``
112
`+
if op >= HAVE_ARGUMENT:
`
``
113
`+
oparg = code[i] + code[i + 1] * 256 + extended_arg
`
``
114
`+
extended_arg = 0
`
``
115
`+
i += 2
`
``
116
`+
if op == EXTENDED_ARG:
`
``
117
`+
extended_arg = oparg * 65536
`
``
118
`+
if op in GLOBAL_OPS:
`
``
119
`+
yield op, oparg
`
``
120
+
``
121
`+
else:
`
``
122
`+
def _walk_global_ops(code):
`
``
123
`+
"""
`
``
124
`+
Yield (opcode, argument number) tuples for all
`
``
125
`+
global-referencing instructions in code.
`
``
126
`+
"""
`
``
127
`+
for instr in dis.get_instructions(code):
`
``
128
`+
op = instr.opcode
`
``
129
`+
if op in GLOBAL_OPS:
`
``
130
`+
yield op, instr.arg
`
``
131
+
``
132
+
93
133
`class CloudPickler(Pickler):
`
94
134
``
95
135
`dispatch = Pickler.dispatch.copy()
`
`@@ -260,38 +300,34 @@ def save_function_tuple(self, func):
`
260
300
`write(pickle.TUPLE)
`
261
301
`write(pickle.REDUCE) # applies _fill_function on the tuple
`
262
302
``
263
``
`-
@staticmethod
`
264
``
`-
def extract_code_globals(co):
`
``
303
`+
_extract_code_globals_cache = (
`
``
304
`+
weakref.WeakKeyDictionary()
`
``
305
`+
if sys.version_info >= (2, 7) and not hasattr(sys, "pypy_version_info")
`
``
306
`+
else {})
`
``
307
+
``
308
`+
@classmethod
`
``
309
`+
def extract_code_globals(cls, co):
`
265
310
`"""
`
266
311
` Find all globals names read or written to by codeblock co
`
267
312
` """
`
268
``
`-
code = co.co_code
`
269
``
`-
if not PY3:
`
270
``
`-
code = [ord(c) for c in code]
`
271
``
`-
names = co.co_names
`
272
``
`-
out_names = set()
`
273
``
-
274
``
`-
n = len(code)
`
275
``
`-
i = 0
`
276
``
`-
extended_arg = 0
`
277
``
`-
while i < n:
`
278
``
`-
op = code[i]
`
``
313
`+
out_names = cls._extract_code_globals_cache.get(co)
`
``
314
`+
if out_names is None:
`
``
315
`+
try:
`
``
316
`+
names = co.co_names
`
``
317
`+
except AttributeError:
`
``
318
`+
# PyPy "builtin-code" object
`
``
319
`+
out_names = set()
`
``
320
`+
else:
`
``
321
`+
out_names = set(names[oparg]
`
``
322
`+
for op, oparg in _walk_global_ops(co))
`
279
323
``
280
``
`-
i += 1
`
281
``
`-
if op >= HAVE_ARGUMENT:
`
282
``
`-
oparg = code[i] + code[i+1] * 256 + extended_arg
`
283
``
`-
extended_arg = 0
`
284
``
`-
i += 2
`
285
``
`-
if op == EXTENDED_ARG:
`
286
``
`-
extended_arg = oparg*65536
`
287
``
`-
if op in GLOBAL_OPS:
`
288
``
`-
out_names.add(names[oparg])
`
``
324
`+
# see if nested function have any global refs
`
``
325
`+
if co.co_consts:
`
``
326
`+
for const in co.co_consts:
`
``
327
`+
if type(const) is types.CodeType:
`
``
328
`+
out_names |= cls.extract_code_globals(const)
`
289
329
``
290
``
`-
# see if nested function have any global refs
`
291
``
`-
if co.co_consts:
`
292
``
`-
for const in co.co_consts:
`
293
``
`-
if type(const) is types.CodeType:
`
294
``
`-
out_names |= CloudPickler.extract_code_globals(const)
`
``
330
`+
cls._extract_code_globals_cache[co] = out_names
`
295
331
``
296
332
`return out_names
`
297
333
``