remove using_pjrt in xla_graph_executor (#6768) · pytorch/xla@d6fb539

```diff
@@ -1285,8 +1285,6 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
       runtime::sys_util::GetEnvBool("XLA_ENABLE_PARAM_ALIASING", true);
   static const size_t parameter_wrapping_threadshold =
       runtime::sys_util::GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200);
-  static const bool using_pjrt =
-      runtime::sys_util::GetEnvString("PJRT_DEVICE", "").size() > 0;
   static const bool use_autosharding = ShardingUtil::GetAutoSharding();
   LoweringContext lowering_ctx("SyncTensorsGraph", coll.device,
                                po_data->post_order,
@@ -1346,7 +1344,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
   // TODO(yeounoh) enable wrapping with auto-sharding.
   bool should_wrap_parameter =
       (program_shape.parameters_size() >= parameter_wrapping_threadshold) &&
-      using_pjrt && !use_autosharding;
+      !use_autosharding;
   if (should_wrap_parameter) {
     TF_VLOG(3) << "Wrapping graph with " << program_shape.parameters_size()
                << " parameters. Threadshold = "
```
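For context, after this change the parameter-wrapping decision depends only on the parameter count crossing `XLA_PARAMETER_WRAPPING_THREADSHOLD` and on auto-sharding being off; the `PJRT_DEVICE` check is dropped. Below is a minimal, self-contained C++ sketch of the resulting decision logic. It uses `std::getenv` directly as a stand-in for `runtime::sys_util::GetEnvInt`, and `param_count` / `use_autosharding` are hypothetical example values, not the real inputs from `XLAGraphExecutor::Compile`:

```cpp
#include <cstdlib>
#include <iostream>
#include <string>

// Hypothetical stand-in for runtime::sys_util::GetEnvInt; the real helper
// lives in torch_xla/csrc/runtime/sys_util.h.
static long GetEnvInt(const char* name, long defval) {
  const char* value = std::getenv(name);
  return value ? std::strtol(value, nullptr, 10) : defval;
}

int main() {
  // Mirrors the post-change condition: wrap when the parameter count
  // reaches the threshold and auto-sharding is disabled. The old
  // using_pjrt (PJRT_DEVICE) check no longer participates.
  static const long parameter_wrapping_threadshold =
      GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200);

  long param_count = 4000;        // assumed example value
  bool use_autosharding = false;  // assumed example value

  bool should_wrap_parameter =
      (param_count >= parameter_wrapping_threadshold) && !use_autosharding;

  std::cout << "should_wrap_parameter = " << std::boolalpha
            << should_wrap_parameter << "\n";
  return 0;
}
```

Note that `parameter_wrapping_threadshold` and `XLA_PARAMETER_WRAPPING_THREADSHOLD` keep their original (misspelled) names from the source, since they are existing identifiers in the repository.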