useChat.ts 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. import { useXAgent, XStream } from "@ant-design/x";
  2. import { useEffect, useRef, useState } from "react";
  3. import { useSessionStorageState } from "ahooks";
  4. import { GetSessionList, GetSessionMessageList } from "@/api/ai";
  5. import type { ConversationsProps } from "@ant-design/x";
  6. import type { ReactNode } from "react";
  /**
   * A chat message as held in the hook's message list.
   */
  type MessageItem = {
    /** Message id. */
    id: string;
    /** Message body: plain text or arbitrary rendered content. */
    content: string | ReactNode;
    /** Who produced the message. */
    role: "user" | "assistant" | "system";
    /** Lifecycle status of the message. */
    status: "loading" | "done" | "error" | "stop";
    /** Optional loading flag. */
    loading?: boolean;
    /** Optional content rendered with the message (e.g. actions). */
    footer?: ReactNode;
  };
  /**
   * One event of the chat API response, as parsed from a server-sent-event chunk.
   */
  type ResponseMessageItem = {
    /** Answer text carried by this event. */
    answer: string;
    /** Backend conversation id; used to key a newly created conversation. */
    conversation_id: string;
    /** Creation time. NOTE(review): unit (seconds vs. ms) not visible here — confirm. */
    created_at: number;
    /**
     * Event kind: `message` carries streamed content, `message_end` ends the
     * reply, `message_error` reports a failure, `ping` marks the stream start.
     */
    event: "message" | "message_end" | "message_error" | "ping";
    /** Backend message id. */
    message_id: string;
    /** Backend task id. */
    task_id: string;
  };
  /**
   * Parameters of a chat request.
   */
  type ChatParams = {
    /** Application name. */
    app_name: string;
    /** The user's question. */
    chat_query: string;
    /** Conversation name; only sent with the first message of a new conversation. */
    chat_name?: string;
    /** Conversation id; sent with follow-up messages of an existing conversation. */
    conversation_id?: string;
  };
  /**
   * Options accepted by `useChat`.
   */
  type ChatProps = {
    /** Application name sent with every backend call. */
    app_name: string;
    /**
     * Conversation id for follow-up messages.
     * NOTE(review): `useChat` does not destructure or read this option — confirm intent.
     */
    conversation_id?: string;
    /** Called with the event that ends a successful reply. */
    onSuccess?: (data: ResponseMessageItem) => void;
    /** Called with every streamed content event, to update the message being written. */
    onUpdate: (data: ResponseMessageItem) => void;
    /** Called when a chat request fails. */
    onError?: (error: Error) => void;
  };
  /**
   * Placeholder entry for a conversation that the backend has not created yet.
   * Key "1" is a local sentinel (backend conversations are keyed by their
   * session id); the label reads "New conversation".
   */
  const defaultConversation = {
    key: "1",
    label: "新的对话",
  };
  /**
   * Chat state for one application: the conversation list, the messages of the
   * active conversation, and a streaming agent that posts questions to the chat API.
   *
   * The conversation keyed `defaultConversation.key` ("1") is a local placeholder
   * for a chat the backend has not created yet. When the first streamed chunk of
   * its first reply arrives, the placeholder is re-keyed to the backend
   * `conversation_id` and made the active conversation.
   *
   * @param props.app_name  application name sent with every backend call
   * @param props.onSuccess called with the `message_end` event of a reply
   * @param props.onUpdate  called with every streamed `message` event
   * @param props.onError   called when a chat request fails
   * @returns state, setters and actions for a chat UI
   */
  export function useChat({ app_name, onSuccess, onUpdate, onError }: ChatProps) {
    // Key of the local placeholder conversation (not yet persisted by the backend).
    const newConversationKey = defaultConversation.key;

    /** Whether a chat request is in flight. */
    const [loading, setLoading] = useState(false);
    /** Whether the conversation list is loading. */
    const [loadingSession, setLoadingSession] = useState(false);
    /** Whether the message history of a conversation is loading. */
    const [loadingMessages, setLoadingMessages] = useState(false);
    // Controller of the most recent chat request; `cancel` aborts it.
    const abortController = useRef<AbortController | null>(null);
    // Bumped on every conversation switch so that a slower, older history
    // request cannot overwrite the messages of a newer selection.
    const messagesRequestSeq = useRef(0);
    /** Messages of the active conversation. */
    const [messages, setMessages] = useState<Array<MessageItem>>([]);
    /** Conversation list; starts with just the placeholder conversation. */
    const [conversationList, setConversationList] = useState<ConversationsProps["items"]>([
      { ...defaultConversation },
    ]);
    /** Key of the active conversation. */
    const [activeConversation, setActiveConversation] = useState(newConversationKey);
    // Current agent object kept in sessionStorage under "agent-map".
    // NOTE(review): neither value is read in this hook — confirm it is still needed.
    const [currentAgent, setCurrentAgent] = useSessionStorageState("agent-map");

    // Load the conversation list whenever the application changes.
    useEffect(() => {
      // Set by the cleanup when app_name changes (or on unmount) so a late
      // response for the previous app cannot overwrite the newer list.
      let stale = false;
      setLoadingSession(true);
      GetSessionList({
        app_name,
        page_index: 1,
      })
        .then((res) => {
          if (stale) return;
          setConversationList([
            { ...defaultConversation },
            ...(res?.result?.model || []).map((item: any) => ({
              ...item,
              key: item.sessionId,
              label: item.name,
            })),
          ]);
        })
        .catch((error: unknown) => {
          // Previously an unhandled rejection; keep the placeholder-only list.
          console.error("GetSessionList failed", error);
        })
        .finally(() => {
          if (!stale) setLoadingSession(false);
        });
      return () => {
        stale = true;
      };
    }, [app_name]);

    /**
     * Switch to another conversation and load its message history.
     * Selecting the placeholder conversation just clears the message list.
     * Errors from the history request propagate to the caller.
     * @param key conversation id
     */
    const changeConversation = async (key: string) => {
      const requestId = ++messagesRequestSeq.current;
      setActiveConversation(key);
      if (key === newConversationKey) {
        setMessages([]);
        // Any pending history request is now obsolete and won't clear the flag.
        setLoadingMessages(false);
        return;
      }
      setLoadingMessages(true);
      try {
        const res = await GetSessionMessageList({
          app_name,
          session_id: key,
          page_index: 1,
        });
        // The user switched conversations while this was loading; drop it.
        if (requestId !== messagesRequestSeq.current) return;
        // Response schema is not typed here; each record holds a query/answer pair.
        const records: any[] = res?.result?.model || [];
        const list = records.flatMap((item): MessageItem[] => [
          {
            id: item.id + "_query",
            content: item.query,
            role: "user",
            status: "done",
          },
          {
            // Distinct suffix: the user and assistant messages of one record
            // previously shared an id, producing duplicate React keys.
            id: item.id + "_answer",
            content: item.answer,
            role: "assistant",
            status: "done",
          },
        ]);
        setMessages(list);
      } finally {
        if (requestId === messagesRequestSeq.current) setLoadingMessages(false);
      }
    };

    /**
     * Agent that posts a question to the chat API and relays the streamed
     * events to the callbacks supplied by `agent.request`.
     */
    const [agent] = useXAgent<ResponseMessageItem>({
      request: async (
        message,
        { onError: emitError, onSuccess: emitSuccess, onUpdate: emitUpdate }
      ) => {
        const controller = new AbortController();
        abortController.current = controller;
        const { signal } = controller;
        try {
          setLoading(true);
          const response = await fetch("https://design.shalu.com/api/ai/chat-message", {
            method: "POST",
            body: JSON.stringify(message),
            headers: {
              Authorization: localStorage.getItem("token_a") || "",
              "Content-Type": "application/json",
            },
            signal,
          });
          const isStream = response.headers
            .get("content-type")
            ?.includes("text/event-stream");
          if (isStream) {
            if (!response.body) return;
            for await (const chunk of XStream({ readableStream: response.body })) {
              const data = JSON.parse(chunk.data);
              switch (data?.event) {
                case "message":
                  emitUpdate(data);
                  break;
                case "message_end":
                  emitSuccess(data);
                  break;
                case "message_error":
                  // NOTE(review): forwards the raw event object, not an Error instance.
                  emitError(data);
                  break;
                case "ping":
                  console.log(">>>> stream start <<<<");
                  break;
              }
            }
          } else {
            // A non-stream response is treated as an API error payload. It is
            // awaited so a parse failure reaches the catch below and `loading`
            // stays set until the payload has been handled.
            const res = await response.json();
            if (res?.code === 0) {
              emitError(new Error(res?.error || "请求失败"));
              // Abort this request's own controller (the shared ref may
              // already point at a newer request).
              controller.abort();
            }
          }
        } catch (error) {
          // A user-initiated abort is not an error.
          if (signal.aborted) return;
          emitError(error instanceof Error ? error : new Error(String(error)));
        } finally {
          // Only the most recent request owns the loading flag.
          if (abortController.current === controller) setLoading(false);
        }
      },
    });

    /**
     * Send a question in the active conversation.
     * @param chat_query the user's question
     */
    const onRequest = (chat_query: string) => {
      const isNewConversation = activeConversation === newConversationKey;
      if (isNewConversation) {
        // Name the placeholder conversation after its first question.
        setConversationList((list) =>
          list?.map((item) =>
            item.key === newConversationKey ? { ...item, label: chat_query } : item
          )
        );
      }
      // The placeholder is re-keyed once, on the first streamed chunk; later
      // chunks (and placeholders created meanwhile) are left alone.
      let conversationBound = !isNewConversation;
      agent.request(
        {
          app_name,
          chat_query,
          chat_name: isNewConversation ? chat_query : undefined,
          conversation_id: isNewConversation ? undefined : activeConversation,
        },
        {
          onSuccess: (data) => {
            onSuccess?.(data);
          },
          onUpdate: (data) => {
            onUpdate(data);
            if (conversationBound) return;
            conversationBound = true;
            // Replace the placeholder key with the backend conversation id.
            setConversationList((list) =>
              list?.map((item) =>
                item.key === newConversationKey
                  ? { ...item, key: data.conversation_id }
                  : item
              )
            );
            setActiveConversation(data.conversation_id);
          },
          onError: (error) => {
            console.log("error", error);
            onError?.(error);
          },
        }
      );
    };

    /** Stop the most recent chat request, if any. */
    const cancel = () => {
      abortController.current?.abort();
    };

    /**
     * Start a new conversation: clear the messages, activate the placeholder,
     * and add a placeholder entry if the list no longer has one.
     */
    const addConversation = () => {
      // Invalidate any pending history request so it cannot fill the new chat.
      messagesRequestSeq.current += 1;
      setLoadingMessages(false);
      setMessages([]);
      setActiveConversation(newConversationKey);
      // If the placeholder is still present (nothing was sent yet), reuse it;
      // otherwise prepend a fresh one.
      setConversationList((list) =>
        list?.some((item) => item.key === newConversationKey)
          ? list
          : [{ ...defaultConversation }, ...(list || [])]
      );
    };

    return {
      agent,
      loading,
      loadingMessages,
      loadingSession,
      cancel,
      messages,
      setMessages,
      conversationList,
      setConversationList,
      activeConversation,
      setActiveConversation,
      onRequest,
      addConversation,
      changeConversation,
    };
  }