package com.infoepoch.pms.dispatchassistant.controller.langchain;

import com.infoepoch.pms.dispatchassistant.common.component.Result;
import com.infoepoch.pms.dispatchassistant.common.exception.NotLoginException;
import com.infoepoch.pms.dispatchassistant.common.exception.ValidationException;
import com.infoepoch.pms.dispatchassistant.common.utils.JsonUtils;
import com.infoepoch.pms.dispatchassistant.controller.basic.Auth;
import com.infoepoch.pms.dispatchassistant.controller.system.DictTypeConstants;
import com.infoepoch.pms.dispatchassistant.domain.basic.user.User;
import com.infoepoch.pms.dispatchassistant.domain.langchain.ChatMessage;
import com.infoepoch.pms.dispatchassistant.domain.langchain.chat.ChatResponse;
import com.infoepoch.pms.dispatchassistant.domain.langchain.chat.ChatService;
import com.infoepoch.pms.dispatchassistant.domain.langchain.chat.chat.ChatRequest;
import com.infoepoch.pms.dispatchassistant.domain.langchain.chat.chat.ChatSuccessResponse;
import com.infoepoch.pms.dispatchassistant.domain.langchain.chat.knowledgeBaseChat.KnowledgeBaseChatRequest;
import com.infoepoch.pms.dispatchassistant.domain.langchain.chat.knowledgeBaseChat.KnowledgeBaseChatResponse;
import com.infoepoch.pms.dispatchassistant.domain.langchain.record.conversation.Conversation;
import com.infoepoch.pms.dispatchassistant.domain.langchain.record.conversation.ConversationCriteria;
import com.infoepoch.pms.dispatchassistant.domain.langchain.record.conversation.ConversationService;
import com.infoepoch.pms.dispatchassistant.domain.langchain.record.conversation.line.ConversationLine;
import com.infoepoch.pms.dispatchassistant.domain.langchain.record.conversation.line.ConversationLineCriteria;
import com.infoepoch.pms.dispatchassistant.domain.langchain.record.conversation.line.ConversationLineService;
import com.infoepoch.pms.dispatchassistant.domain.system.dict.DictDataService;
import com.infoepoch.pms.dispatchassistant.domain.system.dict.SystemDictData;
import io.micrometer.core.instrument.util.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;

import java.math.BigDecimal;
import java.util.Arrays;
import java.util.List;

@RestController
@RequestMapping("/mobile/langchain")
public class MobileLangChainController {

    /** Dialogue type: plain LLM chat that replays recent history of the conversation. */
    private static final String TYPE_CHAT = "chat";
    /** Dialogue type: knowledge-base question answering. */
    private static final String TYPE_KNOWLEDGE = "knowledge";
    /** Dialogue type accepted by validation but not handled yet (it yields the generic error result). */
    private static final String TYPE_FAST = "fast";
    /** Dialogue types accepted by {@link #commonChat(ChatMessage)}; built once instead of per request. */
    private static final List<String> ALLOWED_CHAT_TYPES = Arrays.asList(TYPE_CHAT, TYPE_KNOWLEDGE, TYPE_FAST);

    /** Model name sent with both chat and knowledge-base requests. */
    private static final String MODEL_NAME = "chatglm3-6b";
    /**
     * Numeric parameter sent right after the model name in both request types
     * (presumably the sampling temperature — TODO confirm against ChatRequest / KnowledgeBaseChatRequest).
     */
    private static final BigDecimal MODEL_TEMPERATURE = new BigDecimal("0.7");
    /** Second argument passed to {@code queryConversationHistory} (presumably the number of recent lines — verify). */
    private static final int CHAT_HISTORY_LIMIT = 3;

    @Autowired
    private Auth auth;

    @Autowired
    private ChatService chatService;
    @Autowired
    private ConversationService conversationService;
    @Autowired
    private ConversationLineService conversationLineService;
    @Autowired
    private DictDataService dictDataService;

    /**
     * Saves a new conversation together with the current mobile user.
     *
     * @param conversation conversation to save; its type must not be blank
     * @return success result carrying the conversation object after {@code conversationService.save} handled it
     * @throws NotLoginException propagated from {@code auth.getMobileUserReq()}
     * @throws ValidationException if the conversation type is blank
     */
    @PostMapping("/add-conversation")
    public Result addConversation(@RequestBody Conversation conversation) throws NotLoginException {
        if (StringUtils.isBlank(conversation.getType())) {
            throw new ValidationException("会话类型不可为空");
        }
        User user = auth.getMobileUserReq();
        conversationService.save(conversation, user);
        return Result.successData(conversation);
    }

    /**
     * Handles a generic dialogue request by forwarding it to the AI service and storing the exchange.
     *
     * <p>Only {@code "chat"} and {@code "knowledge"} have handlers. {@code "fast"} passes validation
     * but produces the generic error result.
     *
     * <p>NOTE(review): unlike {@code addConversation}/{@code conversationList}, this endpoint never
     * resolves the mobile user. Confirm it is protected elsewhere (e.g. a filter or interceptor).
     *
     * @param chatMessage dialogue request; its type must be one of {@link #ALLOWED_CHAT_TYPES}
     * @return success result with the {@link ChatResponse}, or error {@code -1} if no answer was produced
     * @throws ValidationException if the type is blank or not allowed
     */
    @PostMapping("/common-chat")
    public Result commonChat(@RequestBody ChatMessage chatMessage) {
        String type = chatMessage.getType();
        if (StringUtils.isBlank(type)) {
            throw new ValidationException("对话类型不可为空");
        }
        if (!ALLOWED_CHAT_TYPES.contains(type)) {
            throw new ValidationException("对话类型不合法");
        }
        ChatResponse response;
        switch (type) {
            case TYPE_CHAT:
                response = chatWithHistory(chatMessage);
                break;
            case TYPE_KNOWLEDGE:
                response = chatWithKnowledgeBase(chatMessage);
                break;
            default:
                // NOTE(review): TYPE_FAST is allowed but has no handler yet, so it
                // falls through to the generic error result below.
                response = null;
                break;
        }
        if (response == null) {
            return Result.error(-1, "获取对话结果失败");
        }
        return Result.successData(response);
    }

    /**
     * Sends a plain chat request that replays recent history of the message's conversation,
     * then stores the exchange as a conversation line.
     *
     * @param chatMessage the validated dialogue request
     * @return the chat response, or {@code null} if the AI service reply parsed to nothing
     */
    private ChatResponse chatWithHistory(ChatMessage chatMessage) {
        ChatRequest request = new ChatRequest(
                chatMessage.getMessage(),
                chatMessage.getConversationId(),
                false,
                MODEL_NAME,
                MODEL_TEMPERATURE,
                0,
                "default"
        );
        List<ConversationLine> historyList =
                conversationService.queryConversationHistory(chatMessage.getConversationId(), CHAT_HISTORY_LIMIT);
        if (historyList != null) {
            // Each stored line contributes the user's query followed by the assistant's answer.
            for (ConversationLine line : historyList) {
                request.addHistory("user", line.getQuery());
                request.addHistory("assistant", line.getAnswer());
            }
        }
        String result = chatService.chat(request);
        ChatSuccessResponse chatSuccessResponse = JsonUtils.jsonToObject(result, ChatSuccessResponse.class);
        if (chatSuccessResponse == null) {
            // Avoid a NullPointerException on an empty/unparseable reply; the caller maps null to an error result.
            return null;
        }
        ChatResponse response = new ChatResponse(TYPE_CHAT, chatSuccessResponse.getText(), null);
        conversationService.saveLine(chatMessage, request, response, result);
        return response;
    }

    /**
     * Sends a knowledge-base chat request, then stores the exchange as a conversation line.
     *
     * @param chatMessage the validated dialogue request
     * @return the knowledge response (answer plus docs), or {@code null} if the AI service reply parsed to nothing
     */
    private ChatResponse chatWithKnowledgeBase(ChatMessage chatMessage) {
        // NOTE(review): "jiangsu", 3 and 1 are fixed request parameters whose meaning is defined
        // by the KnowledgeBaseChatRequest constructor (not visible here).
        KnowledgeBaseChatRequest request = new KnowledgeBaseChatRequest(
                chatMessage.getMessage(),
                "jiangsu",
                3,
                1,
                false,
                MODEL_NAME,
                MODEL_TEMPERATURE,
                0,
                "default"
        );
        String result = chatService.knowledgeChat(request);
        KnowledgeBaseChatResponse knowledgeBaseChatResponse =
                JsonUtils.jsonToObject(result, KnowledgeBaseChatResponse.class);
        if (knowledgeBaseChatResponse == null) {
            // Avoid a NullPointerException on an empty/unparseable reply; the caller maps null to an error result.
            return null;
        }
        ChatResponse response = new ChatResponse(
                TYPE_KNOWLEDGE,
                knowledgeBaseChatResponse.getAnswer(),
                knowledgeBaseChatResponse.getDocs()
        );
        conversationService.saveLine(chatMessage, request, response, result);
        return response;
    }

    /**
     * Lists the conversation types, read from system dictionary data of type
     * {@code DictTypeConstants.CONVERSATION_TYPE}.
     *
     * @return success result carrying the dictionary entries
     */
    @GetMapping("/type-list")
    public Result typeList() {
        List<SystemDictData> dictDataList = dictDataService.queryByDictType(DictTypeConstants.CONVERSATION_TYPE);
        return Result.successData(dictDataList);
    }

    /**
     * Lists the current mobile user's conversations of the given type.
     *
     * @param type conversation type taken from the path
     * @return success result carrying the matching conversations
     * @throws NotLoginException propagated from {@code auth.getMobileUserReq()}
     */
    @GetMapping("/conversation-list/{type}")
    public Result conversationList(@PathVariable String type) throws NotLoginException {
        User user = auth.getMobileUserReq();
        ConversationCriteria criteria = new ConversationCriteria();
        criteria.setType(type);
        criteria.setUserId(user.getId());
        List<Conversation> conversationList = conversationService.list(criteria);
        return Result.successData(conversationList);
    }

    /**
     * Lists the stored lines (history) of a conversation.
     *
     * <p>NOTE(review): no user or ownership check is done here, so any caller who knows a
     * conversation id can read its history. Consider filtering by the current user as
     * {@code conversationList} does (this requires the {@code NotLoginException} throws clause).
     *
     * @param conversationId id of the conversation, taken from the path
     * @return success result carrying the conversation lines
     */
    @GetMapping("/conversation-history/{conversationId}")
    public Result conversationHistory(@PathVariable String conversationId) {
        ConversationLineCriteria criteria = new ConversationLineCriteria();
        criteria.setConversationId(conversationId);
        List<ConversationLine> conversationHistory = conversationLineService.list(criteria);
        return Result.successData(conversationHistory);
    }

}
