refactor: Move MessageAggregator to spring-ai-model module · spring-projects/spring-ai@54e5c07 (original) (raw)
``
1
`+
/*
`
``
2
`+
- Copyright 2023-2025 the original author or authors.
`
``
3
`+
`
``
4
`+
- Licensed under the Apache License, Version 2.0 (the "License");
`
``
5
`+
- you may not use this file except in compliance with the License.
`
``
6
`+
- You may obtain a copy of the License at
`
``
7
`+
`
``
8
`+
`
``
9
`+
`
``
10
`+
- Unless required by applicable law or agreed to in writing, software
`
``
11
`+
- distributed under the License is distributed on an "AS IS" BASIS,
`
``
12
`+
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
`
``
13
`+
- See the License for the specific language governing permissions and
`
``
14
`+
- limitations under the License.
`
``
15
`+
*/
`
``
16
+
``
17
`+
package org.springframework.ai.chat.client;
`
``
18
+
``
19
`+
import java.util.HashMap;
`
``
20
`+
import java.util.Map;
`
``
21
`+
import java.util.concurrent.atomic.AtomicReference;
`
``
22
`+
import java.util.function.Consumer;
`
``
23
+
``
24
`+
import org.slf4j.Logger;
`
``
25
`+
import org.slf4j.LoggerFactory;
`
``
26
`+
import reactor.core.publisher.Flux;
`
``
27
+
``
28
`+
import org.springframework.ai.chat.model.MessageAggregator;
`
``
29
+
``
30
`+
/**
`
``
31
`+
- Helper that for streaming chat responses, aggregate the chat response messages into a
`
``
32
`+
- single AssistantMessage. Job is performed in parallel to the chat response processing.
`
``
33
`+
`
``
34
`+
- @author Christian Tzolov
`
``
35
`+
- @author Alexandros Pappas
`
``
36
`+
- @author Thomas Vitale
`
``
37
`+
- @since 1.0.0
`
``
38
`+
*/
`
``
39
`+
public class ChatClientMessageAggregator {
`
``
40
+
``
41
`+
private static final Logger logger = LoggerFactory.getLogger(ChatClientMessageAggregator.class);
`
``
42
+
``
43
`+
public Flux aggregateChatClientResponse(Flux chatClientResponses,
`
``
44
`+
Consumer aggregationHandler) {
`
``
45
+
``
46
`+
AtomicReference<Map<String, Object>> context = new AtomicReference<>(new HashMap<>());
`
``
47
+
``
48
`+
return new MessageAggregator().aggregate(chatClientResponses.mapNotNull(chatClientResponse -> {
`
``
49
`+
context.get().putAll(chatClientResponse.context());
`
``
50
`+
return chatClientResponse.chatResponse();
`
``
51
`+
}), aggregatedChatResponse -> {
`
``
52
`+
ChatClientResponse aggregatedChatClientResponse = ChatClientResponse.builder()
`
``
53
`+
.chatResponse(aggregatedChatResponse)
`
``
54
`+
.context(context.get())
`
``
55
`+
.build();
`
``
56
`+
aggregationHandler.accept(aggregatedChatClientResponse);
`
``
57
`+
}).map(chatResponse -> ChatClientResponse.builder().chatResponse(chatResponse).context(context.get()).build());
`
``
58
`+
}
`
``
59
+
``
60
`+
}
`