Implementing BART with CAR prior (original) (raw)

The full traceback from the sparse-matrix version is below. I also tried running the non-BART model with `car_idata = pm.sample(200, tune=200, nuts_sampler="nutpie", chains=4, blas_cores=16)` to see whether I could get a speedup. That syntax doesn't work for the BART version because the BART variable isn't continuous, so it can't be sampled with NUTS (the same reason you can't run BART in Stan, I assume). Looking at the nutpie source code, am I correct that it doesn't have a multi-threading option, i.e., something like `bart_car_idata = nutpie.sample(compiled_model, chains=4, blas_cores=16)`? I figure that would help with the RAM issue from the dense (non-sparse) matrix, because the matrix work would be partitioned across multiple threads. Alternatively, is there a Numba workaround for the sparse matrix, perhaps using a decorator?

/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run AdvancedSetSubtensor's perform method
  warnings.warn(
/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run SparseDot's perform method
  warnings.warn(
/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run SparseDot's perform method
  warnings.warn(
/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run Eigvalsh{lower=True}'s perform method
  warnings.warn(
/home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:381: UserWarning: Numba will use object mode to run StructuredDot's perform method
  warnings.warn(
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Cell In[33], line 1
----> 1 compiled_model = nutpie.compile_pymc_model(bart_car_model)
      2 bart_car_idata = nutpie.sample(compiled_model)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/nutpie/compile_pymc.py:395, in compile_pymc_model(model, backend, gradient_backend, **kwargs)
    392     backend = "numba"
    394 if backend.lower() == "numba":
--> 395     return _compile_pymc_model_numba(model, **kwargs)
    396 elif backend.lower() == "jax":
    397     return _compile_pymc_model_jax(
    398         model, gradient_backend=gradient_backend, **kwargs
    399     )

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/nutpie/compile_pymc.py:207, in _compile_pymc_model_numba(model, **kwargs)
    200 with warnings.catch_warnings():
    201     warnings.filterwarnings(
    202         "ignore",
    203         message="Cannot cache compiled function .* as it uses dynamic globals",
    204         category=numba.NumbaWarning,
    205     )
--> 207     logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
    209 expand_shared_names = [var.name for var in expand_fn_pt.get_shared()]
    210 expand_numba_raw, c_sig_expand = _make_c_expand_func(
    211     n_dim, n_expanded, expand_fn, user_data, expand_shared_names, shared_data
    212 )

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/decorators.py:275, in cfunc.<locals>.wrapper(func)
    273 if cache:
    274     res.enable_caching()
--> 275 res.compile()
    276 return res

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/ccallback.py:68, in CFunc.compile(self)
     65 cres = self._cache.load_overload(self._sig,
     66                                  self._targetdescr.target_context)
     67 if cres is None:
---> 68     cres = self._compile_uncached()
     69     self._cache.save_overload(self._sig, cres)
     70 else:

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/ccallback.py:82, in CFunc._compile_uncached(self)
     79 sig = self._sig
     81 # Compile native function as well as cfunc wrapper
---> 82 return self._compiler.compile(sig.args, sig.return_type)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/dispatcher.py:84, in _FunctionCompiler.compile(self, args, return_type)
     82     return retval
     83 else:
---> 84     raise retval

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/dispatcher.py:94, in _FunctionCompiler._compile_cached(self, args, return_type)
     91     pass
     93 try:
---> 94     retval = self._compile_core(args, return_type)
     95 except errors.TypingError as e:
     96     self._failed_cache[key] = e

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/dispatcher.py:107, in _FunctionCompiler._compile_core(self, args, return_type)
    104 flags = self._customize_flags(flags)
    106 impl = self._get_implementation(args, {})
--> 107 cres = compiler.compile_extra(self.targetdescr.typing_context,
    108                               self.targetdescr.target_context,
    109                               impl,
    110                               args=args, return_type=return_type,
    111                               flags=flags, locals=self.locals,
    112                               pipeline_class=self.pipeline_class)
    113 # Check typing error if object mode is used
    114 if cres.typing_error is not None and not flags.enable_pyobject:

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler.py:744, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    720 """Compiler entry point
    721 
    722 Parameter
   (...)
    740     compiler pipeline
    741 """
    742 pipeline = pipeline_class(typingctx, targetctx, library,
    743                           args, return_type, flags, locals)
--> 744 return pipeline.compile_extra(func)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler.py:438, in CompilerBase.compile_extra(self, func)
    436 self.state.lifted = ()
    437 self.state.lifted_from = None
--> 438 return self._compile_bytecode()

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler.py:506, in CompilerBase._compile_bytecode(self)
    502 """
    503 Populate and run pipeline for bytecode input
    504 """
    505 assert self.state.func_ir is None
--> 506 return self._compile_core()

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler.py:485, in CompilerBase._compile_core(self)
    483         self.state.status.fail_reason = e
    484         if is_final_pipeline:
--> 485             raise e
    486 else:
    487     raise CompilerError("All available pipelines exhausted")

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler.py:472, in CompilerBase._compile_core(self)
    470 res = None
    471 try:
--> 472     pm.run(self.state)
    473     if self.state.cr is not None:
    474         break

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_machinery.py:368, in PassManager.run(self, state)
    365 msg = "Failed in %s mode pipeline (step: %s)" % \
    366     (self.pipeline_name, pass_desc)
    367 patched_exception = self._patch_error(msg, e)
--> 368 raise patched_exception

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
    354 pass_inst = _pass_registry.get(pss).pass_inst
    355 if isinstance(pass_inst, CompilerPass):
--> 356     self._runPass(idx, pass_inst, state)
    357 else:
    358     raise BaseException("Legacy pass in use")

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
    309     mutated |= check(pss.run_initialization, internal_state)
    310 with SimpleTimer() as pass_time:
--> 311     mutated |= check(pss.run_pass, internal_state)
    312 with SimpleTimer() as finalize_time:
    313     mutated |= check(pss.run_finalizer, internal_state)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state)
    272 def check(func, compiler_state):
--> 273     mangled = func(compiler_state)
    274     if mangled not in (True, False):
    275         msg = ("CompilerPass implementations should return True/False. "
    276                "CompilerPass with name '%s' did not.")

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/typed_passes.py:112, in BaseTypeInference.run_pass(self, state)
    106 """
    107 Type inference and legalization
    108 """
    109 with fallback_context(state, 'Function "%s" failed type inference'
    110                       % (state.func_id.func_name,)):
    111     # Type inference
--> 112     typemap, return_type, calltypes, errs = type_inference_stage(
    113         state.typingctx,
    114         state.targetctx,
    115         state.func_ir,
    116         state.args,
    117         state.return_type,
    118         state.locals,
    119         raise_errors=self._raise_errors)
    120     state.typemap = typemap
    121     # save errors in case of partial typing

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/typed_passes.py:93, in type_inference_stage(typingctx, targetctx, interp, args, return_type, locals, raise_errors)
     91     infer.build_constraint()
     92     # return errors in case of partial typing
---> 93     errs = infer.propagate(raise_errors=raise_errors)
     94     typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)
     96 return _TypingResults(typemap, restype, calltypes, errs)

File ~/.conda/envs/pymc_env/lib/python3.12/site-packages/numba/core/typeinfer.py:1091, in TypeInferer.propagate(self, raise_errors)
   1088 force_lit_args = [e for e in errors
   1089                   if isinstance(e, ForceLiteralArg)]
   1090 if not force_lit_args:
-> 1091     raise errors[0]
   1092 else:
   1093     raise reduce(operator.or_, force_lit_args)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: native lowering)
<numba.core.base.OverloadSelector object at 0x15230dfa8860>, (csr_matrix,)
During: lowering "$446load_global.65 = global(sparse_constant: <Compressed Sparse Row sparse matrix of dtype 'float64'
    with 8754 stored elements and shape (1633, 1633)>
  Coords	Values
  (0, 26)	1.0
  (0, 512)	1.0
  (0, 513)	1.0
  (0, 1504)	1.0
  (0, 1547)	1.0
  (1, 200)	1.0
  (1, 483)	1.0
  (1, 536)	1.0
  (1, 1091)	1.0
  (1, 1108)	1.0
  (1, 1513)	1.0
  (1, 1518)	1.0
  (1, 1546)	1.0
  (1, 1551)	1.0
  (1, 1554)	1.0
  (2, 3)	1.0
  (2, 199)	1.0
  (2, 244)	1.0
  (2, 1622)	1.0
  (3, 2)	1.0
  (3, 11)	1.0
  (3, 199)	1.0
  (3, 249)	1.0
  (3, 1550)	1.0
  (3, 1622)	1.0
  :	:
  (1628, 1629)	1.0
  (1629, 18)	1.0
  (1629, 1456)	1.0
  (1629, 1458)	1.0
  (1629, 1487)	1.0
  (1629, 1627)	1.0
  (1629, 1628)	1.0
  (1630, 195)	1.0
  (1630, 198)	1.0
  (1630, 275)	1.0
  (1630, 276)	1.0
  (1630, 725)	1.0
  (1630, 1004)	1.0
  (1630, 1631)	1.0
  (1631, 193)	1.0
  (1631, 194)	1.0
  (1631, 195)	1.0
  (1631, 275)	1.0
  (1631, 590)	1.0
  (1631, 595)	1.0
  (1631, 596)	1.0
  (1631, 1630)	1.0
  (1632, 193)	1.0
  (1632, 196)	1.0
  (1632, 197)	1.0)" at /tmp/tmpgeij1h_b (25)
During: resolving callee type: type(CPUDispatcher(<function numba_funcified_fgraph at 0x1523059a1300>))
During: typing of call at /home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/nutpie/compile_pymc.py (558)

During: resolving callee type: type(CPUDispatcher(<function numba_funcified_fgraph at 0x1523059a1300>))
During: typing of call at /home/hawkinslab/jfhawkin/.conda/envs/pymc_env/lib/python3.12/site-packages/nutpie/compile_pymc.py (558)


File ".conda/envs/pymc_env/lib/python3.12/site-packages/nutpie/compile_pymc.py", line 558:
        def extract_shared(x, user_data_):
            return inner(x)
            ^