Implementing BART with a CAR prior
The full traceback from the sparse-matrix version is below. I also tried running the non-BART model with car_idata = pm.sample(200, tune=200, nuts_sampler="nutpie", chains=4, blas_cores=16) to see whether I could get a speedup. This syntax doesn't work for the BART version because the BART variable isn't continuous (the same reason you can't run BART in Stan, I assume). Looking at the nutpie source code, am I correct that it doesn't have a multi-threading option, e.g. bart_car_idata = nutpie.sample(compiled_model, chains=4, blas_cores=16)? I figured that would help with the RAM issue caused by the dense (non-sparse) matrix, because the matrix would be partitioned across multiple threads. Alternatively, is there a Numba workaround for the sparse matrix, perhaps using a decorator?
/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run AdvancedSetSubtensor's perform method
warnings.warn(
/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run SparseDot's perform method
warnings.warn(
/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run SparseDot's perform method
warnings.warn(
/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run Eigvalsh{lower=True}'s perform method
warnings.warn(
/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run StructuredDot's perform method
warnings.warn(
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
Cell In[33], line 1
----> 1 compiled_model = nutpie.compile_pymc_model(bart_car_model)
2 bart_car_idata = nutpie.sample(compiled_model)
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/nutpie/compile_pymc.py:395, in compile_pymc_model(model, backend, gradient_backend, **kwargs)
392 backend = "numba"
394 if backend.lower() == "numba":
--> 395 return _compile_pymc_model_numba(model, **kwargs)
396 elif backend.lower() == "jax":
397 return _compile_pymc_model_jax(
398 model, gradient_backend=gradient_backend, **kwargs
399 )
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/nutpie/compile_pymc.py:207, in _compile_pymc_model_numba(model, **kwargs)
200 with warnings.catch_warnings():
201 warnings.filterwarnings(
202 "ignore",
203 message="Cannot cache compiled function .* as it uses dynamic globals",
204 category=numba.NumbaWarning,
205 )
--> 207 logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
209 expand_shared_names = [var.name for var in expand_fn_pt.get_shared()]
210 expand_numba_raw, c_sig_expand = _make_c_expand_func(
211 n_dim, n_expanded, expand_fn, user_data, expand_shared_names, shared_data
212 )
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/decorators.py:275, in cfunc.<locals>.wrapper(func)
273 if cache:
274 res.enable_caching()
--> 275 res.compile()
276 return res
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
32 @functools.wraps(func)
33 def _acquire_compile_lock(*args, **kwargs):
34 with self:
---> 35 return func(*args, **kwargs)
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/ccallback.py:68, in CFunc.compile(self)
65 cres = self._cache.load_overload(self._sig,
66 self._targetdescr.target_context)
67 if cres is None:
---> 68 cres = self._compile_uncached()
69 self._cache.save_overload(self._sig, cres)
70 else:
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/ccallback.py:82, in CFunc._compile_uncached(self)
79 sig = self._sig
81 # Compile native function as well as cfunc wrapper
---> 82 return self._compiler.compile(sig.args, sig.return_type)
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/dispatcher.py:84, in _FunctionCompiler.compile(self, args, return_type)
82 return retval
83 else:
---> 84 raise retval
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/dispatcher.py:94, in _FunctionCompiler._compile_cached(self, args, return_type)
91 pass
93 try:
---> 94 retval = self._compile_core(args, return_type)
95 except errors.TypingError as e:
96 self._failed_cache[key] = e
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/dispatcher.py:107, in _FunctionCompiler._compile_core(self, args, return_type)
104 flags = self._customize_flags(flags)
106 impl = self._get_implementation(args, {})
--> 107 cres = compiler.compile_extra(self.targetdescr.typing_context,
108 self.targetdescr.target_context,
109 impl,
110 args=args, return_type=return_type,
111 flags=flags, locals=self.locals,
112 pipeline_class=self.pipeline_class)
113 # Check typing error if object mode is used
114 if cres.typing_error is not None and not flags.enable_pyobject:
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler.py:744, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
720 """Compiler entry point
721
722 Parameter
(...)
740 compiler pipeline
741 """
742 pipeline = pipeline_class(typingctx, targetctx, library,
743 args, return_type, flags, locals)
--> 744 return pipeline.compile_extra(func)
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler.py:438, in CompilerBase.compile_extra(self, func)
436 self.state.lifted = ()
437 self.state.lifted_from = None
--> 438 return self._compile_bytecode()
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler.py:506, in CompilerBase._compile_bytecode(self)
502 """
503 Populate and run pipeline for bytecode input
504 """
505 assert self.state.func_ir is None
--> 506 return self._compile_core()
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler.py:485, in CompilerBase._compile_core(self)
483 self.state.status.fail_reason = e
484 if is_final_pipeline:
--> 485 raise e
486 else:
487 raise CompilerError("All available pipelines exhausted")
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler.py:472, in CompilerBase._compile_core(self)
470 res = None
471 try:
--> 472 pm.run(self.state)
473 if self.state.cr is not None:
474 break
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_machinery.py:368, in PassManager.run(self, state)
365 msg = "Failed in %s mode pipeline (step: %s)" % \
366 (self.pipeline_name, pass_desc)
367 patched_exception = self._patch_error(msg, e)
--> 368 raise patched_exception
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
354 pass_inst = _pass_registry.get(pss).pass_inst
355 if isinstance(pass_inst, CompilerPass):
--> 356 self._runPass(idx, pass_inst, state)
357 else:
358 raise BaseException("Legacy pass in use")
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
32 @functools.wraps(func)
33 def _acquire_compile_lock(*args, **kwargs):
34 with self:
---> 35 return func(*args, **kwargs)
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
309 mutated |= check(pss.run_initialization, internal_state)
310 with SimpleTimer() as pass_time:
--> 311 mutated |= check(pss.run_pass, internal_state)
312 with SimpleTimer() as finalize_time:
313 mutated |= check(pss.run_finalizer, internal_state)
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state)
272 def check(func, compiler_state):
--> 273 mangled = func(compiler_state)
274 if mangled not in (True, False):
275 msg = ("CompilerPass implementations should return True/False. "
276 "CompilerPass with name '%s' did not.")
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/typed_passes.py:112, in BaseTypeInference.run_pass(self, state)
106 """
107 Type inference and legalization
108 """
109 with fallback_context(state, 'Function "%s" failed type inference'
110 % (state.func_id.func_name,)):
111 # Type inference
--> 112 typemap, return_type, calltypes, errs = type_inference_stage(
113 state.typingctx,
114 state.targetctx,
115 state.func_ir,
116 state.args,
117 state.return_type,
118 state.locals,
119 raise_errors=self._raise_errors)
120 state.typemap = typemap
121 # save errors in case of partial typing
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/typed_passes.py:93, in type_inference_stage(typingctx, targetctx, interp, args, return_type, locals, raise_errors)
91 infer.build_constraint()
92 # return errors in case of partial typing
---> 93 errs = infer.propagate(raise_errors=raise_errors)
94 typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)
96 return _TypingResults(typemap, restype, calltypes, errs)
File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/typeinfer.py:1091, in TypeInferer.propagate(self, raise_errors)
1088 force_lit_args = [e for e in errors
1089 if isinstance(e, ForceLiteralArg)]
1090 if not force_lit_args:
-> 1091 raise errors[0]
1092 else:
1093 raise reduce(operator.or_, force_lit_args)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: native lowering)
<numba.core.base.OverloadSelector object at 0x15230dfa8860>, (csr_matrix,)
During: lowering "$446load_global.65 = global(sparse_constant: <Compressed Sparse Row sparse matrix of dtype 'float64'
with 8754 stored elements and shape (1633, 1633)>
Coords Values
(0, 26) 1.0
(0, 512) 1.0
(0, 513) 1.0
(0, 1504) 1.0
(0, 1547) 1.0
(1, 200) 1.0
(1, 483) 1.0
(1, 536) 1.0
(1, 1091) 1.0
(1, 1108) 1.0
(1, 1513) 1.0
(1, 1518) 1.0
(1, 1546) 1.0
(1, 1551) 1.0
(1, 1554) 1.0
(2, 3) 1.0
(2, 199) 1.0
(2, 244) 1.0
(2, 1622) 1.0
(3, 2) 1.0
(3, 11) 1.0
(3, 199) 1.0
(3, 249) 1.0
(3, 1550) 1.0
(3, 1622) 1.0
: :
(1628, 1629) 1.0
(1629, 18) 1.0
(1629, 1456) 1.0
(1629, 1458) 1.0
(1629, 1487) 1.0
(1629, 1627) 1.0
(1629, 1628) 1.0
(1630, 195) 1.0
(1630, 198) 1.0
(1630, 275) 1.0
(1630, 276) 1.0
(1630, 725) 1.0
(1630, 1004) 1.0
(1630, 1631) 1.0
(1631, 193) 1.0
(1631, 194) 1.0
(1631, 195) 1.0
(1631, 275) 1.0
(1631, 590) 1.0
(1631, 595) 1.0
(1631, 596) 1.0
(1631, 1630) 1.0
(1632, 193) 1.0
(1632, 196) 1.0
(1632, 197) 1.0)" at /tmp/tmpgeij1h_b (25)
During: resolving callee type: type(CPUDispatcher(<function numba_funcified_fgraph at 0x1523059a1300>))
During: typing of call at /home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/nutpie/compile_pymc.py (558)
During: resolving callee type: type(CPUDispatcher(<function numba_funcified_fgraph at 0x1523059a1300>))
During: typing of call at /home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/nutpie/compile_pymc.py (558)
File ".conda/envs/pymc_env/lib/python3.12/site-packages/nutpie/compile_pymc.py", line 558:
def extract_shared(x, user_data_):
return inner(x)
^