remove using_pjrt in xla_graph_executor (#6768) · pytorch/xla@d6fb539 (original) (raw)
`@@ -1285,8 +1285,6 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
`
1285
1285
`runtime::sys_util::GetEnvBool("XLA_ENABLE_PARAM_ALIASING", true);
`
1286
1286
`static const size_t parameter_wrapping_threadshold =
`
1287
1287
`runtime::sys_util::GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200);
`
1288
``
`-
static const bool using_pjrt =
`
1289
``
`-
runtime::sys_util::GetEnvString("PJRT_DEVICE", "").size() > 0;
`
1290
1288
`static const bool use_autosharding = ShardingUtil::GetAutoSharding();
`
1291
1289
` LoweringContext lowering_ctx("SyncTensorsGraph", coll.device,
`
1292
1290
` po_data->post_order,
`
`@@ -1346,7 +1344,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
`
1346
1344
`// TODO(yeounoh) enable wrapping with auto-sharding.
`
1347
1345
`bool should_wrap_parameter =
`
1348
1346
` (program_shape.parameters_size() >= parameter_wrapping_threadshold) &&
`
1349
``
`-
using_pjrt && !use_autosharding;
`
``
1347
`+
!use_autosharding;
`
1350
1348
`if (should_wrap_parameter) {
`
1351
1349
`TF_VLOG(3) << "Wrapping graph with " << program_shape.parameters_size()
`
1352
1350
` << " parameters. Threadshold = "
`