Auto merge of #66282 - Centril:simplify-try, r= · rust-lang/rust@0113342 (original) (raw)

``

1

`+

//! The general point of the optimizations provided here is to simplify something like:

`

``

2

`+

//!

`

``

3


//! ```rust

``

4

`+

//! match x {

`

``

5

`+

//! Ok(x) => Ok(x),

`

``

6

`+

//! Err(x) => Err(x)

`

``

7

`+

//! }

`

``

8


//! ```

``

9

`+

//!

`

``

10

`` +

//! into just x.

``

``

11

+

``

12

`+

use crate::transform::{MirPass, MirSource, simplify};

`

``

13

`+

use rustc::ty::{TyCtxt, Ty};

`

``

14

`+

use rustc::mir::*;

`

``

15

`+

use rustc_target::abi::VariantIdx;

`

``

16

`+

use itertools::Itertools as _;

`

``

17

+

``

18

`` +

/// Simplifies arms of form Variant(x) => Variant(x) to just a move.

``

``

19

`+

///

`

``

20

`+

/// This is done by transforming basic blocks where the statements match:

`

``

21

`+

///

`

``

22


/// ```rust

``

23

`+

/// _LOCAL_TMP = ((_LOCAL_1 as Variant ).FIELD: TY );

`

``

24

`+

/// ((_LOCAL_0 as Variant).FIELD: TY) = move _LOCAL_TMP;

`

``

25

`+

/// discriminant(_LOCAL_0) = VAR_IDX;

`

``

26


/// ```

``

27

`+

///

`

``

28

`+

/// into:

`

``

29

`+

///

`

``

30


/// ```rust

``

31

`+

/// _LOCAL_0 = move _LOCAL_1

`

``

32


/// ```

``

33

`+

pub struct SimplifyArmIdentity;

`

``

34

+

``

35

`+

impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {

`

``

36

`+

fn run_pass(&self, _: TyCtxt<'tcx>, _: MirSource<'tcx>, body: &mut Body<'tcx>) {

`

``

37

`+

for bb in body.basic_blocks_mut() {

`

``

38

`+

// Need 3 statements:

`

``

39

`+

let (s0, s1, s2) = match &mut *bb.statements {

`

``

40

`+

[s0, s1, s2] => (s0, s1, s2),

`

``

41

`+

_ => continue,

`

``

42

`+

};

`

``

43

+

``

44

`+

// Pattern match on the form we want:

`

``

45

`+

let (local_tmp_s0, local_1, vf_s0) = match match_get_variant_field(s0) {

`

``

46

`+

None => continue,

`

``

47

`+

Some(x) => x,

`

``

48

`+

};

`

``

49

`+

let (local_tmp_s1, local_0, vf_s1) = match match_set_variant_field(s1) {

`

``

50

`+

None => continue,

`

``

51

`+

Some(x) => x,

`

``

52

`+

};

`

``

53

`+

if local_tmp_s0 != local_tmp_s1

`

``

54

`+

|| vf_s0 != vf_s1

`

``

55

`+

|| Some((local_0, vf_s0.var_idx)) != match_set_discr(s2)

`

``

56

`+

{

`

``

57

`+

continue;

`

``

58

`+

}

`

``

59

+

``

60

`+

// Right shape; transform!

`

``

61

`+

match &mut s0.kind {

`

``

62

`+

StatementKind::Assign(box (place, rvalue)) => {

`

``

63

`+

*place = local_0.into();

`

``

64

`+

*rvalue = Rvalue::Use(Operand::Move(local_1.into()));

`

``

65

`+

}

`

``

66

`+

_ => unreachable!(),

`

``

67

`+

}

`

``

68

`+

s1.make_nop();

`

``

69

`+

s2.make_nop();

`

``

70

`+

}

`

``

71

`+

}

`

``

72

`+

}

`

``

73

+

``

74

`+

/// Match on:

`

``

75


/// ```rust

``

76

`+

/// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);

`

``

77


/// ```

``

78

`+

fn match_get_variant_field<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, Local, VarField<'tcx>)> {

`

``

79

`+

match &stmt.kind {

`

``

80

`+

StatementKind::Assign(box (place_into, rvalue_from)) => match rvalue_from {

`

``

81

`+

Rvalue::Use(Operand::Copy(pf)) | Rvalue::Use(Operand::Move(pf)) => {

`

``

82

`+

let local_into = place_into.as_local()?;

`

``

83

`+

let (local_from, vf) = match_variant_field_place(&pf)?;

`

``

84

`+

Some((local_into, local_from, vf))

`

``

85

`+

}

`

``

86

`+

_ => None,

`

``

87

`+

},

`

``

88

`+

_ => None,

`

``

89

`+

}

`

``

90

`+

}

`

``

91

+

``

92

`+

/// Match on:

`

``

93


/// ```rust

``

94

`+

/// ((_LOCAL_FROM as Variant).FIELD: TY) = move _LOCAL_INTO;

`

``

95


/// ```

``

96

`+

fn match_set_variant_field<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, Local, VarField<'tcx>)> {

`

``

97

`+

match &stmt.kind {

`

``

98

`+

StatementKind::Assign(box (place_from, rvalue_into)) => match rvalue_into {

`

``

99

`+

Rvalue::Use(Operand::Move(place_into)) => {

`

``

100

`+

let local_into = place_into.as_local()?;

`

``

101

`+

let (local_from, vf) = match_variant_field_place(&place_from)?;

`

``

102

`+

Some((local_into, local_from, vf))

`

``

103

`+

}

`

``

104

`+

_ => None,

`

``

105

`+

},

`

``

106

`+

_ => None,

`

``

107

`+

}

`

``

108

`+

}

`

``

109

+

``

110

`+

/// Match on:

`

``

111


/// ```rust

``

112

`+

/// discriminant(_LOCAL_TO_SET) = VAR_IDX;

`

``

113


/// ```

``

114

`+

fn match_set_discr<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, VariantIdx)> {

`

``

115

`+

match &stmt.kind {

`

``

116

`+

StatementKind::SetDiscriminant { place, variant_index } => Some((

`

``

117

`+

place.as_local()?,

`

``

118

`+

*variant_index

`

``

119

`+

)),

`

``

120

`+

_ => None,

`

``

121

`+

}

`

``

122

`+

}

`

``

123

+

``

124

`+

#[derive(PartialEq)]

`

``

125

`+

struct VarField<'tcx> {

`

``

126

`+

field: Field,

`

``

127

`+

field_ty: Ty<'tcx>,

`

``

128

`+

var_idx: VariantIdx,

`

``

129

`+

}

`

``

130

+

``

131

`` +

/// Match on ((_LOCAL as Variant).FIELD: TY).

``

``

132

`+

fn match_variant_field_place<'tcx>(place: &Place<'tcx>) -> Option<(Local, VarField<'tcx>)> {

`

``

133

`+

match place.as_ref() {

`

``

134

`+

PlaceRef {

`

``

135

`+

base: &PlaceBase::Local(local),

`

``

136

`+

projection: &[ProjectionElem::Downcast(_, var_idx), ProjectionElem::Field(field, ty)],

`

``

137

`+

} => Some((local, VarField { field, field_ty: ty, var_idx })),

`

``

138

`+

_ => None,

`

``

139

`+

}

`

``

140

`+

}

`

``

141

+

``

142

`` +

/// Simplifies SwitchInt(_) -> [targets],

``

``

143

`` +

/// where all the targets have the same form,

``

``

144

`` +

/// into goto -> target_first.

``

``

145

`+

pub struct SimplifyBranchSame;

`

``

146

+

``

147

`+

impl<'tcx> MirPass<'tcx> for SimplifyBranchSame {

`

``

148

`+

fn run_pass(&self, _: TyCtxt<'tcx>, _: MirSource<'tcx>, body: &mut Body<'tcx>) {

`

``

149

`+

let bbs = body.basic_blocks_mut();

`

``

150

`+

for bb_idx in bbs.indices() {

`

``

151

`+

let targets = match &bbs[bb_idx].terminator().kind {

`

``

152

`+

TerminatorKind::SwitchInt { targets, .. } => targets,

`

``

153

`+

_ => continue,

`

``

154

`+

};

`

``

155

+

``

156

`` +

// Reaching unreachable is UB so assume it doesn't happen.

``

``

157

`+

let mut iter_bbs_reachable = targets

`

``

158

`+

.iter()

`

``

159

`+

.map(|idx| (*idx, &bbs[*idx]))

`

``

160

`+

.filter(|(_, bb)| bb.terminator().kind != TerminatorKind::Unreachable)

`

``

161

`+

.peekable();

`

``

162

+

``

163

`` +

// We want to goto -> bb_first.

``

``

164

`+

let bb_first = iter_bbs_reachable

`

``

165

`+

.peek()

`

``

166

`+

.map(|(idx, _)| *idx)

`

``

167

`+

.unwrap_or(targets[0]);

`

``

168

+

``

169

`+

// All successor basic blocks should have the exact same form.

`

``

170

`+

let all_successors_equivalent = iter_bbs_reachable

`

``

171

`+

.map(|(_, bb)| bb)

`

``

172

`+

.tuple_windows()

`

``

173

`+

.all(|(bb_l, bb_r)| {

`

``

174

`+

bb_l.is_cleanup == bb_r.is_cleanup

`

``

175

`+

&& bb_l.terminator().kind == bb_r.terminator().kind

`

``

176

`+

&& bb_l.statements.iter().eq_by(&bb_r.statements, |x, y| x.kind == y.kind)

`

``

177

`+

});

`

``

178

+

``

179

`+

if all_successors_equivalent {

`

``

180

`` +

// Replace SwitchInt(..) -> [bb_first, ..]; with a goto -> bb_first;.

``

``

181

`+

bbs[bb_idx].terminator_mut().kind = TerminatorKind::Goto { target: bb_first };

`

``

182

`+

}

`

``

183

`+

}

`

``

184

+

``

185

`+

// We may have dead blocks now, so remvoe those.

`

``

186

`+

simplify::remove_dead_blocks(body);

`

``

187

`+

}

`

``

188

`+

}

`

``

189

+

``

190

`+

/*

`

``

191

+

``

192

`+

KEEPSAKE: REMOVE IF NOT NECESSARY!

`

``

193

+

``

194

`+

fn statement_semantic_eq(sa: &StatementKind<'_>, sb: &StatementKind<'_>) -> bool {

`

``

195

`+

use StatementKind::*;

`

``

196

`+

match (sb, sa) {

`

``

197

`+

(AscribeUserType(pa, va), AscribeUserType(pb, vb)) => pa == pb && va == vb,

`

``

198

`+

(Assign(a), Assign(b)) => a == b,

`

``

199

`+

(FakeRead(fa, pa), FakeRead(fb, pb)) => fa == fb && pa == pb,

`

``

200

`+

(InlineAsm(a), InlineAsm(b)) => a == b,

`

``

201

`+

(Nop, StatementKind::Nop) => true,

`

``

202

`+

(Retag(ra, pa), Retag(rb, pb)) => ra == rb && pa == pb,

`

``

203

`+

(

`

``

204

`+

SetDiscriminant { place: pa, variant_index: va },

`

``

205

`+

SetDiscriminant { place: pb, variant_index: vb },

`

``

206

`+

) => pa == pb && va == vb,

`

``

207

`+

(StorageDead(a), StorageDead(b)) => a == b,

`

``

208

`+

(StorageLive(a), StorageLive(b)) => a == b,

`

``

209

`+

(AscribeUserType(..), ) | (, AscribeUserType(..))

`

``

210

`+

| (StorageDead(..), ) | (, StorageDead(..))

`

``

211

`+

| (Assign(..), ) | (, Assign(..))

`

``

212

`+

| (FakeRead(..), ) | (, FakeRead(..))

`

``

213

`+

| (InlineAsm(..), ) | (, InlineAsm(..))

`

``

214

`+

| (Nop, ) | (, Nop)

`

``

215

`+

| (Retag(..), ) | (, Retag(..))

`

``

216

`+

| (SetDiscriminant { .. }, ) | (, SetDiscriminant { .. }) => true,

`

``

217

`+

}

`

``

218

`+

}

`

``

219

+

``

220

`+

fn terminator_semantic_eq(ta: &TerminatorKind<'_>, tb: &TerminatorKind<'_>) -> bool {

`

``

221

`+

use TerminatorKind::*;

`

``

222

`+

match (ta, tb) {

`

``

223

`+

(Goto { target: a }, Goto { target: b }) => a == b,

`

``

224

`+

(Resume, Resume)

`

``

225

`+

| (Abort, Abort)

`

``

226

`+

| (Return, Return)

`

``

227

`+

| (Unreachable, Unreachable)

`

``

228

`+

| (GeneratorDrop, GeneratorDrop) => true,

`

``

229

`+

(

`

``

230

`+

SwitchInt { discr: da, switch_ty: sa, targets: ta, values: _ },

`

``

231

`+

SwitchInt { discr: db, switch_ty: sb, targets: tb, values: _ },

`

``

232

`+

) => da == db && sa == sb && ta == tb,

`

``

233

`+

(

`

``

234

`+

Drop { location: la, target: ta, unwind: ua },

`

``

235

`+

Drop { location: lb, target: tb, unwind: ub },

`

``

236

`+

) => la == lb && ta == tb && ua == ub,

`

``

237

`+

(

`

``

238

`+

DropAndReplace { location: la, target: ta, unwind: ua, value: va },

`

``

239

`+

DropAndReplace { location: lb, target: tb, unwind: ub, value: vb },

`

``

240

`+

) => la == lb && ta == tb && ua == ub && va == vb,

`

``

241

`+

(

`

``

242

`+

Call { func: fa, args: aa, destination: da, cleanup: ca, from_hir_call: _ },

`

``

243

`+

Call { func: fb, args: ab, destination: db, cleanup: cb, from_hir_call: _ },

`

``

244

`+

) => fa == fb && aa == ab && da == db && ca == cb,

`

``

245

`+

(

`

``

246

`+

Assert { cond: coa, expected: ea, msg: ma, target: ta, cleanup: cla },

`

``

247

`+

Assert { cond: cob, expected: eb, msg: mb, target: tb, cleanup: clb },

`

``

248

`+

) => coa == cob && ea == eb && ma == mb && ta == tb && cla == clb,

`

``

249

`+

(

`

``

250

`+

Yield { value: va, resume: ra, drop: da },

`

``

251

`+

Yield { value: vb, resume: rb, drop: db },

`

``

252

`+

) => va == vb && ra == rb && da == db,

`

``

253

`+

(

`

``

254

`+

FalseEdges { real_target: ra, imaginary_target: ia },

`

``

255

`+

FalseEdges { real_target: rb, imaginary_target: ib },

`

``

256

`+

) => ra == rb && ia == ib,

`

``

257

`+

(

`

``

258

`+

FalseUnwind { real_target: ra, unwind: ua },

`

``

259

`+

FalseUnwind { real_target: rb, unwind: ub },

`

``

260

`+

) => ra == rb && ua == ub,

`

``

261

`+

(Goto { .. }, ) | (, Goto { .. })

`

``

262

`+

| (Resume, ) | (, Resume)

`

``

263

`+

| (Abort, ) | (, Abort)

`

``

264

`+

| (Return, ) | (, Return)

`

``

265

`+

| (Unreachable, ) | (, Unreachable)

`

``

266

`+

| (GeneratorDrop, ) | (, GeneratorDrop)

`

``

267

`+

| (SwitchInt { .. }, ) | (, SwitchInt { .. })

`

``

268

`+

| (Drop { .. }, ) | (, Drop { .. })

`

``

269

`+

| (DropAndReplace { .. }, ) | (, DropAndReplace { .. })

`

``

270

`+

| (Call { .. }, ) | (, Call { .. })

`

``

271

`+

| (Assert { .. }, ) | (, Assert { .. })

`

``

272

`+

| (Yield { .. }, ) | (, Yield { .. })

`

``

273

`+

| (FalseEdges { .. }, ) | (, FalseEdges { .. }) => false,

`

``

274

`+

}

`

``

275

`+

}

`

``

276

`+

*/

`