Avoid re-computing computation hashes by rpsilva-aws · Pull Request #8976 · pytorch/xla

Currently, we recompute the hash of the underlying computation on every hash lookup, solely for logging in two places. For small models where tracing is not negligible, this has a measurable impact, particularly because we deserialize the protobuf deterministically (which requires ordering the entries of unordered dictionary/map types). The logging itself is unchanged, but the underlying deserialization logic is relatively slow, since it must guarantee deterministic hashes for user computations. C++ evaluates the arguments of stream operators eagerly, so the cost is incurred regardless of the logging level.
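
As a minimal illustration of the eager-evaluation pitfall (not the actual torch_xla logging macro), streamed arguments are ordinary function arguments, so an expensive hash passed to `operator<<` is computed even when the sink discards the message:

```cpp
#include <functional>
#include <iostream>
#include <string>

// Stand-in for the expensive deterministic proto hash (hypothetical).
std::string ComputeProtoHash(const std::string& serialized_proto) {
  // Imagine deterministic ordering of map entries plus hashing here.
  return std::to_string(std::hash<std::string>{}(serialized_proto));
}

// A sink that drops everything it receives.
struct NullLogger {
  template <typename T>
  NullLogger& operator<<(const T&) { return *this; }
};

int main() {
  NullLogger log;
  std::string proto = "...serialized computation...";
  // ComputeProtoHash still runs: its result is an argument to operator<<,
  // and C++ evaluates arguments before the call, discarded or not.
  log << "computation hash: " << ComputeProtoHash(proto);
  return 0;
}
```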

This is only observed when the model is tracing-bound; we recently saw a ~5% throughput impact for small BERT models.

Note that this hash is only used to provide a unique string that a hash key maps to. The hash of the protobuf is only meaningful for UserComputation computations, where it is factored into the hash key. In all other cases it is unnecessary and serves purely as a unique (debug) identifier, and the user can still verify the mapping for any given graph hash key by enabling post_compilation_analysis.
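
A hedged sketch of the key composition described above (illustrative names only, not the actual torch_xla code): the proto hash is folded into the key only for user computations and skipped everywhere else.

```cpp
#include <cstdint>
#include <functional>
#include <string>

using hash_t = uint64_t;

// Hypothetical hash combiner, in the spirit of boost::hash_combine.
hash_t HashCombine(hash_t a, hash_t b) {
  return a ^ (b + 0x9e3779b97f4a7c15ULL + (a << 6) + (a >> 2));
}

// Only UserComputation needs the expensive proto hash in its key; for every
// other computation type the graph hash alone identifies the entry.
hash_t GraphHashKey(hash_t graph_hash, bool is_user_computation,
                    const std::string& serialized_proto) {
  if (is_user_computation) {
    return HashCombine(graph_hash,
                       std::hash<std::string>{}(serialized_proto));
  }
  return graph_hash;
}
```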

We see this during hash lookup, which is evaluated every time, and also in Compile, though there it only occurs for the very first computation (across all instances). The user can still access the computation proto hash by enabling PT_XLA_DEBUG.
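
One way to keep the identifier available without paying for it on the hot path is to compute it lazily behind the debug flag; a rough sketch (the flag name mirrors PT_XLA_DEBUG, the helper names are made up):

```cpp
#include <cstdlib>
#include <functional>
#include <iostream>
#include <string>

// Hypothetical helper: true when the PT_XLA_DEBUG-style flag is set.
bool DebugEnabled() {
  static const bool enabled = std::getenv("PT_XLA_DEBUG") != nullptr;
  return enabled;
}

// The expensive hash is produced only when the debug output is requested,
// so cache lookups and Compile never pay for it implicitly.
void MaybeLogComputationHash(const std::function<std::string()>& hash_fn) {
  if (DebugEnabled()) {
    std::cout << "computation hash: " << hash_fn() << std::endl;
  }
}

// Usage: MaybeLogComputationHash([&] { return ComputeProtoHash(proto); });
```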

For example, for BERT HF pretraining (20 steps), with 48 metrics of 27 samples each, the collective tracing of the hash computation metric is as follows:

- Average rate: ~1.98 ops/second
- Most rates fall between 1.4 and 2.5 ops/second, with a few outliers
- Highest rate: 7.26772 ops/second (outlier)
- Lowest rate: ~1.42 ops/second

- Typical p50 (median) latency per op: ~8-9 microseconds
- Typical p95 latency per op: ~450-500 microseconds
- Typical p99 latency per op: ~500-600 microseconds