adds message counting to protect against malicious overflow (#1067) · rsocket/rsocket-java@af021d9 (original) (raw)
`@@ -88,6 +88,8 @@ final class RequestChannelResponderSubscriber extends Flux
`
88
88
``
89
89
`boolean inboundDone;
`
90
90
`boolean outboundDone;
`
``
91
`+
long requested;
`
``
92
`+
long produced;
`
91
93
``
92
94
`public RequestChannelResponderSubscriber(
`
93
95
`int streamId,
`
`@@ -179,6 +181,8 @@ public void request(long n) {
`
179
181
`return;
`
180
182
` }
`
181
183
``
``
184
`+
this.requested = Operators.addCap(this.requested, n);
`
``
185
+
182
186
`long previousState = StateUtils.addRequestN(STATE, this, n);
`
183
187
`if (isTerminated(previousState)) {
`
184
188
`// full termination can be the result of both sides completion / cancelFrame / remote or local
`
`@@ -196,6 +200,9 @@ public void request(long n) {
`
196
200
`Payload firstPayload = this.firstPayload;
`
197
201
`if (firstPayload != null) {
`
198
202
`this.firstPayload = null;
`
``
203
+
``
204
`+
this.produced++;
`
``
205
+
199
206
`inboundSubscriber.onNext(firstPayload);
`
200
207
` }
`
201
208
``
`@@ -216,6 +223,8 @@ public void request(long n) {
`
216
223
`final Payload firstPayload = this.firstPayload;
`
217
224
`this.firstPayload = null;
`
218
225
``
``
226
`+
this.produced++;
`
``
227
+
219
228
`inboundSubscriber.onNext(firstPayload);
`
220
229
`inboundSubscriber.onComplete();
`
221
230
``
`@@ -238,6 +247,9 @@ public void request(long n) {
`
238
247
``
239
248
`final Payload firstPayload = this.firstPayload;
`
240
249
`this.firstPayload = null;
`
``
250
+
``
251
`+
this.produced++;
`
``
252
+
241
253
`inboundSubscriber.onNext(firstPayload);
`
242
254
``
243
255
`previousState = markFirstFrameSent(STATE, this);
`
`@@ -416,6 +428,58 @@ final void handlePayload(Payload p) {
`
416
428
`return;
`
417
429
` }
`
418
430
``
``
431
`+
final long produced = this.produced;
`
``
432
`+
if (this.requested == produced) {
`
``
433
`+
p.release();
`
``
434
+
``
435
`+
this.inboundDone = true;
`
``
436
+
``
437
`+
final Throwable cause =
`
``
438
`+
Exceptions.failWithOverflow(
`
``
439
`+
"The number of messages received exceeds the number requested");
`
``
440
`+
boolean wasThrowableAdded = Exceptions.addThrowable(INBOUND_ERROR, this, cause);
`
``
441
+
``
442
`+
long previousState = markTerminated(STATE, this);
`
``
443
`+
if (isTerminated(previousState)) {
`
``
444
`+
if (!wasThrowableAdded) {
`
``
445
`+
Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext());
`
``
446
`+
}
`
``
447
`+
return;
`
``
448
`+
}
`
``
449
+
``
450
`+
this.requesterResponderSupport.remove(this.streamId, this);
`
``
451
+
``
452
`+
this.connection.sendFrame(
`
``
453
`+
streamId,
`
``
454
`+
ErrorFrameCodec.encode(
`
``
455
`+
this.allocator, streamId, new CanceledException(cause.getMessage())));
`
``
456
+
``
457
`+
if (!isSubscribed(previousState)) {
`
``
458
`+
final Payload firstPayload = this.firstPayload;
`
``
459
`+
this.firstPayload = null;
`
``
460
`+
firstPayload.release();
`
``
461
`+
} else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) {
`
``
462
`+
Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this);
`
``
463
`+
if (inboundError != TERMINATED) {
`
``
464
`+
//noinspection ConstantConditions
`
``
465
`+
this.inboundSubscriber.onError(inboundError);
`
``
466
`+
}
`
``
467
`+
}
`
``
468
+
``
469
`+
// this is downstream subscription so need to cancel it just in case error signal has not
`
``
470
`+
// reached it
`
``
471
`+
// needs for disconnected upstream and downstream case
`
``
472
`+
this.outboundSubscription.cancel();
`
``
473
+
``
474
`+
final RequestInterceptor interceptor = requestInterceptor;
`
``
475
`+
if (interceptor != null) {
`
``
476
`+
interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, cause);
`
``
477
`+
}
`
``
478
`+
return;
`
``
479
`+
}
`
``
480
+
``
481
`+
this.produced = produced + 1;
`
``
482
+
419
483
`this.inboundSubscriber.onNext(p);
`
420
484
` }
`
421
485
` }
`