refactor: Simplify chat memory advisor hierarchy and remove deprecate… · spring-projects/spring-ai@848a3fd (original) (raw)
`@@ -20,24 +20,25 @@
`
20
20
`import java.util.HashMap;
`
21
21
`import java.util.List;
`
22
22
`import java.util.Map;
`
23
``
`-
import java.util.stream.Collectors;
`
24
23
``
25
``
`-
import reactor.core.publisher.Flux;
`
``
24
`+
import org.slf4j.Logger;
`
``
25
`+
import org.slf4j.LoggerFactory;
`
``
26
`+
import reactor.core.scheduler.Scheduler;
`
``
27
`+
import reactor.core.scheduler.Schedulers;
`
26
28
``
27
``
`-
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
`
28
29
`import org.springframework.ai.chat.client.ChatClientRequest;
`
29
30
`import org.springframework.ai.chat.client.ChatClientResponse;
`
30
``
`-
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
`
31
``
`-
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
`
``
31
`+
import org.springframework.ai.chat.client.advisor.api.Advisor;
`
``
32
`+
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
`
``
33
`+
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
`
``
34
`+
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
`
``
35
`+
import org.springframework.ai.chat.memory.ChatMemory;
`
32
36
`import org.springframework.ai.chat.messages.AssistantMessage;
`
33
37
`import org.springframework.ai.chat.messages.Message;
`
34
38
`import org.springframework.ai.chat.messages.MessageType;
`
35
``
`-
import org.springframework.ai.chat.messages.SystemMessage;
`
36
39
`import org.springframework.ai.chat.messages.UserMessage;
`
37
``
`-
import org.springframework.ai.chat.model.MessageAggregator;
`
38
40
`import org.springframework.ai.chat.prompt.PromptTemplate;
`
39
41
`import org.springframework.ai.document.Document;
`
40
``
`-
import org.springframework.ai.vectorstore.SearchRequest;
`
41
42
`import org.springframework.ai.vectorstore.VectorStore;
`
42
43
``
43
44
`/**
`
`@@ -48,14 +49,22 @@
`
48
49
` * @author Christian Tzolov
`
49
50
` * @author Thomas Vitale
`
50
51
` * @author Oganes Bozoyan
`
``
52
`+
 * @author Mark Pollack
`
51
53
` * @since 1.0.0
`
52
54
` */
`
53
``
`-
public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<VectorStore> {
`
``
55
`+
public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
`
``
56
+
``
57
`+
public static final String CHAT_MEMORY_RETRIEVE_SIZE_KEY = "chat_memory_response_size";
`
54
58
``
55
59
`private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId";
`
56
60
``
57
61
`private static final String DOCUMENT_METADATA_MESSAGE_TYPE = "messageType";
`
58
62
``
``
63
`+
/**
`
``
64
`+
 * The default chat memory retrieve size to use when no retrieve size is provided.
`
``
65
`+
*/
`
``
66
`+
public static final int DEFAULT_TOP_K = 20;
`
``
67
+
59
68
`private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate("""
`
60
69
` {instructions}
`
61
70
``
`@@ -69,71 +78,84 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect
`
69
78
``
70
79
`private final PromptTemplate systemPromptTemplate;
`
71
80
``
72
``
`-
private VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId,
`
73
``
`-
int chatHistoryWindowSize, boolean protectFromBlocking, PromptTemplate systemPromptTemplate, int order) {
`
74
``
`-
super(vectorStore, defaultConversationId, chatHistoryWindowSize, protectFromBlocking, order);
`
``
81
`+
protected final int defaultChatMemoryRetrieveSize;
`
``
82
+
``
83
`+
private final String defaultConversationId;
`
``
84
+
``
85
`+
private final int order;
`
``
86
+
``
87
`+
private final Scheduler scheduler;
`
``
88
+
``
89
`+
private VectorStore vectorStore;
`
``
90
+
``
91
`+
public VectorStoreChatMemoryAdvisor(PromptTemplate systemPromptTemplate, int defaultChatMemoryRetrieveSize,
`
``
92
`+
String defaultConversationId, int order, Scheduler scheduler, VectorStore vectorStore) {
`
75
93
`this.systemPromptTemplate = systemPromptTemplate;
`
``
94
`+
this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize;
`
``
95
`+
this.defaultConversationId = defaultConversationId;
`
``
96
`+
this.order = order;
`
``
97
`+
this.scheduler = scheduler;
`
``
98
`+
this.vectorStore = vectorStore;
`
76
99
` }
`
77
100
``
78
101
`public static Builder builder(VectorStore chatMemory) {
`
79
102
`return new Builder(chatMemory);
`
80
103
` }
`
81
104
``
82
105
`@Override
`
83
``
`-
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
`
84
``
`-
chatClientRequest = this.before(chatClientRequest);
`
85
``
-
86
``
`-
ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest);
`
87
``
-
88
``
`-
this.after(chatClientResponse);
`
89
``
-
90
``
`-
return chatClientResponse;
`
``
106
`+
public int getOrder() {
`
``
107
`+
return order;
`
91
108
` }
`
92
109
``
93
110
`@Override
`
94
``
`-
public Flux adviseStream(ChatClientRequest chatClientRequest,
`
95
``
`-
StreamAdvisorChain streamAdvisorChain) {
`
96
``
`-
Flux chatClientResponses = this.doNextWithProtectFromBlockingBefore(chatClientRequest,
`
97
``
`-
streamAdvisorChain, this::before);
`
98
``
-
99
``
`-
return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::after);
`
``
111
`+
public Scheduler getScheduler() {
`
``
112
`+
return this.scheduler;
`
100
113
` }
`
101
114
``
102
``
`-
private ChatClientRequest before(ChatClientRequest chatClientRequest) {
`
103
``
`-
String conversationId = this.doGetConversationId(chatClientRequest.context());
`
104
``
`-
int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(chatClientRequest.context());
`
105
``
-
106
``
`-
// 1. Retrieve the chat memory for the current conversation.
`
107
``
`-
var searchRequest = SearchRequest.builder()
`
108
``
`-
.query(chatClientRequest.prompt().getUserMessage().getText())
`
109
``
`-
.topK(chatMemoryRetrieveSize)
`
110
``
`-
.filterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'")
`
``
115
`+
@Override
`
``
116
`+
public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorChain) {
`
``
117
`+
String conversationId = getConversationId(request.context());
`
``
118
`+
String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : "";
`
``
119
`+
int topK = getChatMemoryTopK(request.context());
`
``
120
`+
String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'";
`
``
121
`+
var searchRequest = org.springframework.ai.vectorstore.SearchRequest.builder()
`
``
122
`+
.query(query)
`
``
123
`+
.topK(topK)
`
``
124
`+
.filterExpression(filter)
`
111
125
` .build();
`
``
126
`+
java.util.List<org.springframework.ai.document.Document> documents = this.vectorStore
`
``
127
`+
.similaritySearch(searchRequest);
`
112
128
``
113
``
`-
List documents = this.getChatMemoryStore().similaritySearch(searchRequest);
`
114
``
-
115
``
`-
// 2. Processed memory messages as a string.
`
116
129
`String longTermMemory = documents == null ? ""
`
117
``
`-
: documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));
`
``
130
`+
: documents.stream()
`
``
131
`+
.map(org.springframework.ai.document.Document::getText)
`
``
132
`+
.collect(java.util.stream.Collectors.joining(System.lineSeparator()));
`
118
133
``
119
``
`-
// 2. Augment the system message.
`
120
``
`-
SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
`
``
134
`+
org.springframework.ai.chat.messages.SystemMessage systemMessage = request.prompt().getSystemMessage();
`
121
135
`String augmentedSystemText = this.systemPromptTemplate
`
122
``
`-
.render(Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
`
``
136
`+
.render(java.util.Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
`
123
137
``
124
``
`-
// 3. Create a new request with the augmented system message.
`
125
``
`-
ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
`
126
``
`-
.prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))
`
``
138
`+
ChatClientRequest processedChatClientRequest = request.mutate()
`
``
139
`+
.prompt(request.prompt().augmentSystemMessage(augmentedSystemText))
`
127
140
` .build();
`
128
141
``
129
``
`-
// 4. Add the new user message to the conversation memory.
`
130
``
`-
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
`
131
``
`-
this.getChatMemoryStore().write(toDocuments(List.of(userMessage), conversationId));
`
``
142
`+
org.springframework.ai.chat.messages.UserMessage userMessage = processedChatClientRequest.prompt()
`
``
143
`+
.getUserMessage();
`
``
144
`+
if (userMessage != null) {
`
``
145
`+
this.vectorStore.write(toDocuments(java.util.List.of(userMessage), conversationId));
`
``
146
`+
}
`
132
147
``
133
148
`return processedChatClientRequest;
`
134
149
` }
`
135
150
``
136
``
`-
private void after(ChatClientResponse chatClientResponse) {
`
``
151
`+
private int getChatMemoryTopK(Map<String, Object> context) {
`
``
152
`+
return context.containsKey(CHAT_MEMORY_RETRIEVE_SIZE_KEY)
`
``
153
`+
? Integer.parseInt(context.get(CHAT_MEMORY_RETRIEVE_SIZE_KEY).toString())
`
``
154
`+
: this.defaultChatMemoryRetrieveSize;
`
``
155
`+
}
`
``
156
+
``
157
`+
@Override
`
``
158
`+
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
`
137
159
`List<Message> assistantMessages = new ArrayList<>();
`
138
160
`if (chatClientResponse.chatResponse() != null) {
`
139
161
`assistantMessages = chatClientResponse.chatResponse()
`
`@@ -142,8 +164,8 @@ private void after(ChatClientResponse chatClientResponse) {
`
142
164
` .map(g -> (Message) g.getOutput())
`
143
165
` .toList();
`
144
166
` }
`
145
``
`-
this.getChatMemoryStore()
`
146
``
`-
.write(toDocuments(assistantMessages, this.doGetConversationId(chatClientResponse.context())));
`
``
167
`+
this.vectorStore.write(toDocuments(assistantMessages, this.getConversationId(chatClientResponse.context())));
`
``
168
`+
return chatClientResponse;
`
147
169
` }
`
148
170
``
149
171
`private List<Document> toDocuments(List<Message> messages, String conversationId) {
`
`@@ -173,28 +195,93 @@ else if (message instanceof AssistantMessage assistantMessage) {
`
173
195
`return docs;
`
174
196
` }
`
175
197
``
176
``
`-
public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder {
`
``
198
`+
/**
`
``
199
`+
 * Builder for VectorStoreChatMemoryAdvisor.
`
``
200
`+
*/
`
``
201
`+
public static class Builder {
`
177
202
``
178
203
`private PromptTemplate systemPromptTemplate = DEFAULT_SYSTEM_PROMPT_TEMPLATE;
`
179
204
``
180
``
`-
protected Builder(VectorStore chatMemory) {
`
181
``
`-
super(chatMemory);
`
182
``
`-
}
`
``
205
`+
private Integer topK = DEFAULT_TOP_K;
`
183
206
``
184
``
`-
public Builder systemTextAdvise(String systemTextAdvise) {
`
185
``
`-
this.systemPromptTemplate = new PromptTemplate(systemTextAdvise);
`
186
``
`-
return this;
`
``
207
`+
private String conversationId = ChatMemory.DEFAULT_CONVERSATION_ID;
`
``
208
+
``
209
`+
private Scheduler scheduler;
`
``
210
+
``
211
`+
private int order = Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER;
`
``
212
+
``
213
`+
private VectorStore vectorStore;
`
``
214
+
``
215
`+
/**
`
``
216
`+
 * Creates a new builder instance.
`
``
217
`+
 * @param vectorStore the vector store to use
`
``
218
`+
*/
`
``
219
`+
protected Builder(VectorStore vectorStore) {
`
``
220
`+
this.vectorStore = vectorStore;
`
187
221
` }
`
188
222
``
``
223
`+
/**
`
``
224
`+
 * Set the system prompt template.
`
``
225
`+
 * @param systemPromptTemplate the system prompt template
`
``
226
`+
 * @return this builder
`
``
227
`+
*/
`
189
228
`public Builder systemPromptTemplate(PromptTemplate systemPromptTemplate) {
`
190
229
`this.systemPromptTemplate = systemPromptTemplate;
`
191
230
`return this;
`
192
231
` }
`
193
232
``
194
``
`-
@Override
`
``
233
`+
/**
`
``
234
`+
 * Set the chat memory retrieve size.
`
``
235
`+
 * @param topK the chat memory retrieve size
`
``
236
`+
 * @return this builder
`
``
237
`+
*/
`
``
238
`+
public Builder topK(int topK) {
`
``
239
`+
this.topK = topK;
`
``
240
`+
return this;
`
``
241
`+
}
`
``
242
+
``
243
`+
/**
`
``
244
`+
 * Set the conversation id.
`
``
245
`+
 * @param conversationId the conversation id
`
``
246
`+
 * @return the builder
`
``
247
`+
*/
`
``
248
`+
public Builder conversationId(String conversationId) {
`
``
249
`+
this.conversationId = conversationId;
`
``
250
`+
return this;
`
``
251
`+
}
`
``
252
+
``
253
`+
/**
`
``
254
`+
 * Set whether to protect from blocking.
`
``
255
`+
 * @param protectFromBlocking whether to protect from blocking
`
``
256
`+
 * @return the builder
`
``
257
`+
*/
`
``
258
`+
public Builder protectFromBlocking(boolean protectFromBlocking) {
`
``
259
`+
this.scheduler = protectFromBlocking ? BaseAdvisor.DEFAULT_SCHEDULER : Schedulers.immediate();
`
``
260
`+
return this;
`
``
261
`+
}
`
``
262
+
``
263
`+
public Builder scheduler(Scheduler scheduler) {
`
``
264
`+
this.scheduler = scheduler;
`
``
265
`+
return this;
`
``
266
`+
}
`
``
267
+
``
268
`+
/**
`
``
269
`+
 * Set the order.
`
``
270
`+
 * @param order the order
`
``
271
`+
 * @return the builder
`
``
272
`+
*/
`
``
273
`+
public Builder order(int order) {
`
``
274
`+
this.order = order;
`
``
275
`+
return this;
`
``
276
`+
}
`
``
277
+
``
278
`+
/**
`
``
279
`+
 * Build the advisor.
`
``
280
`+
 * @return the advisor
`
``
281
`+
*/
`
195
282
`public VectorStoreChatMemoryAdvisor build() {
`
196
``
`-
return new VectorStoreChatMemoryAdvisor(this.chatMemory, this.conversationId, this.chatMemoryRetrieveSize,
`
197
``
`-
this.protectFromBlocking, this.systemPromptTemplate, this.order);
`
``
283
`+
return new VectorStoreChatMemoryAdvisor(this.systemPromptTemplate, this.topK, this.conversationId,
`
``
284
`+
this.order, this.scheduler, this.vectorStore);
`
198
285
` }
`
199
286
``
200
287
` }
`