refactor: Move MessageAggregator to spring-ai-model module · spring-projects/spring-ai@54e5c07 (original) (raw)

``

1

`+

/*

`

``

2

`+

`

``

3

`+

`

``

4

`+

`

``

5

`+

`

``

6

`+

`

``

7

`+

`

``

8

`+

`

``

9

`+

`

``

10

`+

`

``

11

`+

`

``

12

`+

`

``

13

`+

`

``

14

`+

`

``

15

`+

*/

`

``

16

+

``

17

`+

package org.springframework.ai.chat.client;

`

``

18

+

``

19

`+

import java.util.HashMap;

`

``

20

`+

import java.util.Map;

`

``

21

`+

import java.util.concurrent.atomic.AtomicReference;

`

``

22

`+

import java.util.function.Consumer;

`

``

23

+

``

24

`+

import org.slf4j.Logger;

`

``

25

`+

import org.slf4j.LoggerFactory;

`

``

26

`+

import reactor.core.publisher.Flux;

`

``

27

+

``

28

`+

import org.springframework.ai.chat.model.MessageAggregator;

`

``

29

+

``

30

`+

/**

`

``

31

`+

`

``

32

`+

`

``

33

`+

`

``

34

`+

`

``

35

`+

`

``

36

`+

`

``

37

`+

`

``

38

`+

*/

`

``

39

`+

public class ChatClientMessageAggregator {

`

``

40

+

``

41

`+

private static final Logger logger = LoggerFactory.getLogger(ChatClientMessageAggregator.class);

`

``

42

+

``

43

`+

public Flux aggregateChatClientResponse(Flux chatClientResponses,

`

``

44

`+

Consumer aggregationHandler) {

`

``

45

+

``

46

`+

AtomicReference<Map<String, Object>> context = new AtomicReference<>(new HashMap<>());

`

``

47

+

``

48

`+

return new MessageAggregator().aggregate(chatClientResponses.mapNotNull(chatClientResponse -> {

`

``

49

`+

context.get().putAll(chatClientResponse.context());

`

``

50

`+

return chatClientResponse.chatResponse();

`

``

51

`+

}), aggregatedChatResponse -> {

`

``

52

`+

ChatClientResponse aggregatedChatClientResponse = ChatClientResponse.builder()

`

``

53

`+

.chatResponse(aggregatedChatResponse)

`

``

54

`+

.context(context.get())

`

``

55

`+

.build();

`

``

56

`+

aggregationHandler.accept(aggregatedChatClientResponse);

`

``

57

`+

}).map(chatResponse -> ChatClientResponse.builder().chatResponse(chatResponse).context(context.get()).build());

`

``

58

`+

}

`

``

59

+

``

60

`+

}

`