refactor: Simplify chat memory advisor hierarchy and remove deprecate… · spring-projects/spring-ai@848a3fd (original) (raw)

`@@ -20,24 +20,25 @@

`

20

20

`import java.util.HashMap;

`

21

21

`import java.util.List;

`

22

22

`import java.util.Map;

`

23

``

`-

import java.util.stream.Collectors;

`

24

23

``

25

``

`-

import reactor.core.publisher.Flux;

`

``

24

`+

import org.slf4j.Logger;

`

``

25

`+

import org.slf4j.LoggerFactory;

`

``

26

`+

import reactor.core.scheduler.Scheduler;

`

``

27

`+

import reactor.core.scheduler.Schedulers;

`

26

28

``

27

``

`-

import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;

`

28

29

`import org.springframework.ai.chat.client.ChatClientRequest;

`

29

30

`import org.springframework.ai.chat.client.ChatClientResponse;

`

30

``

`-

import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;

`

31

``

`-

import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;

`

``

31

`+

import org.springframework.ai.chat.client.advisor.api.Advisor;

`

``

32

`+

import org.springframework.ai.chat.client.advisor.api.AdvisorChain;

`

``

33

`+

import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;

`

``

34

`+

import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;

`

``

35

`+

import org.springframework.ai.chat.memory.ChatMemory;

`

32

36

`import org.springframework.ai.chat.messages.AssistantMessage;

`

33

37

`import org.springframework.ai.chat.messages.Message;

`

34

38

`import org.springframework.ai.chat.messages.MessageType;

`

35

``

`-

import org.springframework.ai.chat.messages.SystemMessage;

`

36

39

`import org.springframework.ai.chat.messages.UserMessage;

`

37

``

`-

import org.springframework.ai.chat.model.MessageAggregator;

`

38

40

`import org.springframework.ai.chat.prompt.PromptTemplate;

`

39

41

`import org.springframework.ai.document.Document;

`

40

``

`-

import org.springframework.ai.vectorstore.SearchRequest;

`

41

42

`import org.springframework.ai.vectorstore.VectorStore;

`

42

43

``

43

44

`/**

`

`@@ -48,14 +49,22 @@

`

48

49

` * @author Christian Tzolov

`

49

50

` * @author Thomas Vitale

`

50

51

` * @author Oganes Bozoyan

`

``

52

`+

`

51

53

` * @since 1.0.0

`

52

54

` */

`

53

``

`-

public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor {

`

``

55

`+

public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {

`

``

56

+

``

57

`+

public static final String CHAT_MEMORY_RETRIEVE_SIZE_KEY = "chat_memory_response_size";

`

54

58

``

55

59

`private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId";

`

56

60

``

57

61

`private static final String DOCUMENT_METADATA_MESSAGE_TYPE = "messageType";

`

58

62

``

``

63

`+

/**

`

``

64

`+

`

``

65

`+

*/

`

``

66

`+

public static final int DEFAULT_TOP_K = 20;

`

``

67

+

59

68

`private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate("""

`

60

69

` {instructions}

`

61

70

``

`@@ -69,71 +78,84 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect

`

69

78

``

70

79

`private final PromptTemplate systemPromptTemplate;

`

71

80

``

72

``

`-

private VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId,

`

73

``

`-

int chatHistoryWindowSize, boolean protectFromBlocking, PromptTemplate systemPromptTemplate, int order) {

`

74

``

`-

super(vectorStore, defaultConversationId, chatHistoryWindowSize, protectFromBlocking, order);

`

``

81

`+

protected final int defaultChatMemoryRetrieveSize;

`

``

82

+

``

83

`+

private final String defaultConversationId;

`

``

84

+

``

85

`+

private final int order;

`

``

86

+

``

87

`+

private final Scheduler scheduler;

`

``

88

+

``

89

`+

private VectorStore vectorStore;

`

``

90

+

``

91

`+

public VectorStoreChatMemoryAdvisor(PromptTemplate systemPromptTemplate, int defaultChatMemoryRetrieveSize,

`

``

92

`+

String defaultConversationId, int order, Scheduler scheduler, VectorStore vectorStore) {

`

75

93

`this.systemPromptTemplate = systemPromptTemplate;

`

``

94

`+

this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize;

`

``

95

`+

this.defaultConversationId = defaultConversationId;

`

``

96

`+

this.order = order;

`

``

97

`+

this.scheduler = scheduler;

`

``

98

`+

this.vectorStore = vectorStore;

`

76

99

` }

`

77

100

``

78

101

`public static Builder builder(VectorStore chatMemory) {

`

79

102

`return new Builder(chatMemory);

`

80

103

` }

`

81

104

``

82

105

`@Override

`

83

``

`-

public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {

`

84

``

`-

chatClientRequest = this.before(chatClientRequest);

`

85

``

-

86

``

`-

ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest);

`

87

``

-

88

``

`-

this.after(chatClientResponse);

`

89

``

-

90

``

`-

return chatClientResponse;

`

``

106

`+

public int getOrder() {

`

``

107

`+

return order;

`

91

108

` }

`

92

109

``

93

110

`@Override

`

94

``

`-

public Flux adviseStream(ChatClientRequest chatClientRequest,

`

95

``

`-

StreamAdvisorChain streamAdvisorChain) {

`

96

``

`-

Flux chatClientResponses = this.doNextWithProtectFromBlockingBefore(chatClientRequest,

`

97

``

`-

streamAdvisorChain, this::before);

`

98

``

-

99

``

`-

return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::after);

`

``

111

`+

public Scheduler getScheduler() {

`

``

112

`+

return this.scheduler;

`

100

113

` }

`

101

114

``

102

``

`-

private ChatClientRequest before(ChatClientRequest chatClientRequest) {

`

103

``

`-

String conversationId = this.doGetConversationId(chatClientRequest.context());

`

104

``

`-

int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(chatClientRequest.context());

`

105

``

-

106

``

`-

// 1. Retrieve the chat memory for the current conversation.

`

107

``

`-

var searchRequest = SearchRequest.builder()

`

108

``

`-

.query(chatClientRequest.prompt().getUserMessage().getText())

`

109

``

`-

.topK(chatMemoryRetrieveSize)

`

110

``

`-

.filterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'")

`

``

115

`+

@Override

`

``

116

`+

public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorChain) {

`

``

117

`+

String conversationId = getConversationId(request.context());

`

``

118

`+

String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : "";

`

``

119

`+

int topK = getChatMemoryTopK(request.context());

`

``

120

`+

String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'";

`

``

121

`+

var searchRequest = org.springframework.ai.vectorstore.SearchRequest.builder()

`

``

122

`+

.query(query)

`

``

123

`+

.topK(topK)

`

``

124

`+

.filterExpression(filter)

`

111

125

` .build();

`

``

126

`+

java.util.List<org.springframework.ai.document.Document> documents = this.vectorStore

`

``

127

`+

.similaritySearch(searchRequest);

`

112

128

``

113

``

`-

List documents = this.getChatMemoryStore().similaritySearch(searchRequest);

`

114

``

-

115

``

`-

// 2. Processed memory messages as a string.

`

116

129

`String longTermMemory = documents == null ? ""

`

117

``

`-

: documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));

`

``

130

`+

: documents.stream()

`

``

131

`+

.map(org.springframework.ai.document.Document::getText)

`

``

132

`+

.collect(java.util.stream.Collectors.joining(System.lineSeparator()));

`

118

133

``

119

``

`-

// 2. Augment the system message.

`

120

``

`-

SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();

`

``

134

`+

org.springframework.ai.chat.messages.SystemMessage systemMessage = request.prompt().getSystemMessage();

`

121

135

`String augmentedSystemText = this.systemPromptTemplate

`

122

``

`-

.render(Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));

`

``

136

`+

.render(java.util.Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));

`

123

137

``

124

``

`-

// 3. Create a new request with the augmented system message.

`

125

``

`-

ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()

`

126

``

`-

.prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))

`

``

138

`+

ChatClientRequest processedChatClientRequest = request.mutate()

`

``

139

`+

.prompt(request.prompt().augmentSystemMessage(augmentedSystemText))

`

127

140

` .build();

`

128

141

``

129

``

`-

// 4. Add the new user message to the conversation memory.

`

130

``

`-

UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();

`

131

``

`-

this.getChatMemoryStore().write(toDocuments(List.of(userMessage), conversationId));

`

``

142

`+

org.springframework.ai.chat.messages.UserMessage userMessage = processedChatClientRequest.prompt()

`

``

143

`+

.getUserMessage();

`

``

144

`+

if (userMessage != null) {

`

``

145

`+

this.vectorStore.write(toDocuments(java.util.List.of(userMessage), conversationId));

`

``

146

`+

}

`

132

147

``

133

148

`return processedChatClientRequest;

`

134

149

` }

`

135

150

``

136

``

`-

private void after(ChatClientResponse chatClientResponse) {

`

``

151

`+

private int getChatMemoryTopK(Map<String, Object> context) {

`

``

152

`+

return context.containsKey(CHAT_MEMORY_RETRIEVE_SIZE_KEY)

`

``

153

`+

? Integer.parseInt(context.get(CHAT_MEMORY_RETRIEVE_SIZE_KEY).toString())

`

``

154

`+

: this.defaultChatMemoryRetrieveSize;

`

``

155

`+

}

`

``

156

+

``

157

`+

@Override

`

``

158

`+

public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {

`

137

159

`List assistantMessages = new ArrayList<>();

`

138

160

`if (chatClientResponse.chatResponse() != null) {

`

139

161

`assistantMessages = chatClientResponse.chatResponse()

`

`@@ -142,8 +164,8 @@ private void after(ChatClientResponse chatClientResponse) {

`

142

164

` .map(g -> (Message) g.getOutput())

`

143

165

` .toList();

`

144

166

` }

`

145

``

`-

this.getChatMemoryStore()

`

146

``

`-

.write(toDocuments(assistantMessages, this.doGetConversationId(chatClientResponse.context())));

`

``

167

`+

this.vectorStore.write(toDocuments(assistantMessages, this.getConversationId(chatClientResponse.context())));

`

``

168

`+

return chatClientResponse;

`

147

169

` }

`

148

170

``

149

171

`private List toDocuments(List messages, String conversationId) {

`

`@@ -173,28 +195,93 @@ else if (message instanceof AssistantMessage assistantMessage) {

`

173

195

`return docs;

`

174

196

` }

`

175

197

``

176

``

`-

public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder {

`

``

198

`+

/**

`

``

199

`+

`

``

200

`+

*/

`

``

201

`+

public static class Builder {

`

177

202

``

178

203

`private PromptTemplate systemPromptTemplate = DEFAULT_SYSTEM_PROMPT_TEMPLATE;

`

179

204

``

180

``

`-

protected Builder(VectorStore chatMemory) {

`

181

``

`-

super(chatMemory);

`

182

``

`-

}

`

``

205

`+

private Integer topK = DEFAULT_TOP_K;

`

183

206

``

184

``

`-

public Builder systemTextAdvise(String systemTextAdvise) {

`

185

``

`-

this.systemPromptTemplate = new PromptTemplate(systemTextAdvise);

`

186

``

`-

return this;

`

``

207

`+

private String conversationId = ChatMemory.DEFAULT_CONVERSATION_ID;

`

``

208

+

``

209

`+

private Scheduler scheduler;

`

``

210

+

``

211

`+

private int order = Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER;

`

``

212

+

``

213

`+

private VectorStore vectorStore;

`

``

214

+

``

215

`+

/**

`

``

216

`+

`

``

217

`+

`

``

218

`+

*/

`

``

219

`+

protected Builder(VectorStore vectorStore) {

`

``

220

`+

this.vectorStore = vectorStore;

`

187

221

` }

`

188

222

``

``

223

`+

/**

`

``

224

`+

`

``

225

`+

`

``

226

`+

`

``

227

`+

*/

`

189

228

`public Builder systemPromptTemplate(PromptTemplate systemPromptTemplate) {

`

190

229

`this.systemPromptTemplate = systemPromptTemplate;

`

191

230

`return this;

`

192

231

` }

`

193

232

``

194

``

`-

@Override

`

``

233

`+

/**

`

``

234

`+

`

``

235

`+

`

``

236

`+

`

``

237

`+

*/

`

``

238

`+

public Builder topK(int topK) {

`

``

239

`+

this.topK = topK;

`

``

240

`+

return this;

`

``

241

`+

}

`

``

242

+

``

243

`+

/**

`

``

244

`+

`

``

245

`+

`

``

246

`+

`

``

247

`+

*/

`

``

248

`+

public Builder conversationId(String conversationId) {

`

``

249

`+

this.conversationId = conversationId;

`

``

250

`+

return this;

`

``

251

`+

}

`

``

252

+

``

253

`+

/**

`

``

254

`+

`

``

255

`+

`

``

256

`+

`

``

257

`+

*/

`

``

258

`+

public Builder protectFromBlocking(boolean protectFromBlocking) {

`

``

259

`+

this.scheduler = protectFromBlocking ? BaseAdvisor.DEFAULT_SCHEDULER : Schedulers.immediate();

`

``

260

`+

return this;

`

``

261

`+

}

`

``

262

+

``

263

`+

public Builder scheduler(Scheduler scheduler) {

`

``

264

`+

this.scheduler = scheduler;

`

``

265

`+

return this;

`

``

266

`+

}

`

``

267

+

``

268

`+

/**

`

``

269

`+

`

``

270

`+

`

``

271

`+

`

``

272

`+

*/

`

``

273

`+

public Builder order(int order) {

`

``

274

`+

this.order = order;

`

``

275

`+

return this;

`

``

276

`+

}

`

``

277

+

``

278

`+

/**

`

``

279

`+

`

``

280

`+

`

``

281

`+

*/

`

195

282

`public VectorStoreChatMemoryAdvisor build() {

`

196

``

`-

return new VectorStoreChatMemoryAdvisor(this.chatMemory, this.conversationId, this.chatMemoryRetrieveSize,

`

197

``

`-

this.protectFromBlocking, this.systemPromptTemplate, this.order);

`

``

283

`+

return new VectorStoreChatMemoryAdvisor(this.systemPromptTemplate, this.topK, this.conversationId,

`

``

284

`+

this.order, this.scheduler, this.vectorStore);

`

198

285

` }

`

199

286

``

200

287

` }

`