Fix zero-1 bug for inferring local ranks by YangFei1990 · Pull Request #5936 · pytorch/xla (original) (raw)

Hi @alanwaketan for the example in the description, say if user put the groups as [[0-7], [8-15], [16-23], [24-31]], i.e. opt = ZeroRedundancyOptimizer(..., sharding_group=[[0-7], [8-15], [16-23], [24-31]], ...). With the existing local rank inferring method self.local_rank = self.global_rank // len(self.sharding_groups), for example global rank [0-7] will have all marked as local rank 0. However the local rank should be the rank of the group that the current rank belongs to, so the global rank [0-7] should also have local rank [0-7]. This PR is to address this issue by infer local rank from the index of its global rank in its sharding group.
If you want to add a test for this, could you share some guidance how could I add the test?