adds message counting to protect against malicious overflow (#1067) · rsocket/rsocket-java@af021d9 (original) (raw)

`@@ -88,6 +88,8 @@ final class RequestChannelResponderSubscriber extends Flux

`

88

88

``

89

89

`boolean inboundDone;

`

90

90

`boolean outboundDone;

`

``

91

`+

long requested;

`

``

92

`+

long produced;

`

91

93

``

92

94

`public RequestChannelResponderSubscriber(

`

93

95

`int streamId,

`

`@@ -179,6 +181,8 @@ public void request(long n) {

`

179

181

`return;

`

180

182

` }

`

181

183

``

``

184

`+

this.requested = Operators.addCap(this.requested, n);

`

``

185

+

182

186

`long previousState = StateUtils.addRequestN(STATE, this, n);

`

183

187

`if (isTerminated(previousState)) {

`

184

188

`// full termination can be the result of both sides completion / cancelFrame / remote or local

`

`@@ -196,6 +200,9 @@ public void request(long n) {

`

196

200

`Payload firstPayload = this.firstPayload;

`

197

201

`if (firstPayload != null) {

`

198

202

`this.firstPayload = null;

`

``

203

+

``

204

`+

this.produced++;

`

``

205

+

199

206

`inboundSubscriber.onNext(firstPayload);

`

200

207

` }

`

201

208

``

`@@ -216,6 +223,8 @@ public void request(long n) {

`

216

223

`final Payload firstPayload = this.firstPayload;

`

217

224

`this.firstPayload = null;

`

218

225

``

``

226

`+

this.produced++;

`

``

227

+

219

228

`inboundSubscriber.onNext(firstPayload);

`

220

229

`inboundSubscriber.onComplete();

`

221

230

``

`@@ -238,6 +247,9 @@ public void request(long n) {

`

238

247

``

239

248

`final Payload firstPayload = this.firstPayload;

`

240

249

`this.firstPayload = null;

`

``

250

+

``

251

`+

this.produced++;

`

``

252

+

241

253

`inboundSubscriber.onNext(firstPayload);

`

242

254

``

243

255

`previousState = markFirstFrameSent(STATE, this);

`

`@@ -416,6 +428,58 @@ final void handlePayload(Payload p) {

`

416

428

`return;

`

417

429

` }

`

418

430

``

``

431

`+

final long produced = this.produced;

`

``

432

`+

if (this.requested == produced) {

`

``

433

`+

p.release();

`

``

434

+

``

435

`+

this.inboundDone = true;

`

``

436

+

``

437

`+

final Throwable cause =

`

``

438

`+

Exceptions.failWithOverflow(

`

``

439

`+

"The number of messages received exceeds the number requested");

`

``

440

`+

boolean wasThrowableAdded = Exceptions.addThrowable(INBOUND_ERROR, this, cause);

`

``

441

+

``

442

`+

long previousState = markTerminated(STATE, this);

`

``

443

`+

if (isTerminated(previousState)) {

`

``

444

`+

if (!wasThrowableAdded) {

`

``

445

`+

Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext());

`

``

446

`+

}

`

``

447

`+

return;

`

``

448

`+

}

`

``

449

+

``

450

`+

this.requesterResponderSupport.remove(this.streamId, this);

`

``

451

+

``

452

`+

this.connection.sendFrame(

`

``

453

`+

streamId,

`

``

454

`+

ErrorFrameCodec.encode(

`

``

455

`+

this.allocator, streamId, new CanceledException(cause.getMessage())));

`

``

456

+

``

457

`+

if (!isSubscribed(previousState)) {

`

``

458

`+

final Payload firstPayload = this.firstPayload;

`

``

459

`+

this.firstPayload = null;

`

``

460

`+

firstPayload.release();

`

``

461

`+

} else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) {

`

``

462

`+

Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this);

`

``

463

`+

if (inboundError != TERMINATED) {

`

``

464

`+

//noinspection ConstantConditions

`

``

465

`+

this.inboundSubscriber.onError(inboundError);

`

``

466

`+

}

`

``

467

`+

}

`

``

468

+

``

469

`+

// this is downstream subscription so need to cancel it just in case error signal has not

`

``

470

`+

// reached it

`

``

471

`+

// needs for disconnected upstream and downstream case

`

``

472

`+

this.outboundSubscription.cancel();

`

``

473

+

``

474

`+

final RequestInterceptor interceptor = requestInterceptor;

`

``

475

`+

if (interceptor != null) {

`

``

476

`+

interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, cause);

`

``

477

`+

}

`

``

478

`+

return;

`

``

479

`+

}

`

``

480

+

``

481

`+

this.produced = produced + 1;

`

``

482

+

419

483

`this.inboundSubscriber.onNext(p);

`

420

484

` }

`

421

485

` }

`