package com.infoepoch.pms.dispatchassistant.domain.langchain.chat.fastChat;



import com.infoepoch.pms.dispatchassistant.common.exception.ValidationException;
import com.infoepoch.pms.dispatchassistant.domain.langchain.chat.ChatBaseRequest;
import io.micrometer.core.instrument.util.StringUtils;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Request object for a FastChat chat call.
 *
 * <p>Collects the model name, the conversation messages and the generation
 * parameters, validates them via {@link #verify()} and converts them into a
 * parameter map via {@link #toMap()}.
 *
 * <p>Instances are mutable ({@link #addMessage(String, String)} appends to the
 * message list and {@link #verify()} fills in defaults); no synchronization is
 * performed by this class.
 */
public class FastChatRequest extends ChatBaseRequest {

    /**
     * Appends a message to the conversation list, creating the list on first use.
     *
     * @param role    role of the message author; must not be blank
     * @param content message text; must not be blank
     * @throws ValidationException if {@code role} or {@code content} is blank
     */
    public void addMessage(String role, String content) {
        if (StringUtils.isBlank(role)) {
            throw new ValidationException("角色不可为空");
        }
        if (StringUtils.isBlank(content)) {
            throw new ValidationException("内容不可为空");
        }
        // The list is created lazily, so it stays null until the first message is added.
        if (this.messages == null) {
            this.messages = new ArrayList<>();
        }
        this.messages.add(new Message(role, content));
    }

    /**
     * Validates the request and fills in defaults for optional parameters.
     *
     * <p>Required: a non-blank model name, at least one message, a temperature
     * within [0, 1], and non-negative {@code n} and {@code maxTokens}.
     * Optional parameters left {@code null} are defaulted in place:
     * {@code stream} to {@code false}, {@code presencePenalty} and
     * {@code frequencyPenalty} to {@code 0}. {@code stop} is not validated.
     *
     * @throws ValidationException if any required parameter is missing or out of range
     */
    public void verify() {
        if (StringUtils.isBlank(this.model)) {
            throw new ValidationException("模型名称不可为空");
        }
        if (this.messages == null || this.messages.isEmpty()) {
            throw new ValidationException("消息列表不可为空");
        }
        if (this.temperature == null) {
            throw new ValidationException("temperature不可为空");
        }
        // compareTo (not equals) so that scale differences such as 1.0 vs 1 do not matter.
        if (this.temperature.compareTo(BigDecimal.ZERO) < 0 || this.temperature.compareTo(BigDecimal.ONE) > 0) {
            throw new ValidationException("temperature参数值不合法,请确保temperature在[0, 1]之间");
        }
        if (this.n == null) {
            throw new ValidationException("历史对话轮数不可为空");
        }
        if (this.n < 0) {
            throw new ValidationException("历史对话轮数不可小于0");
        }
        if (this.maxTokens == null) {
            throw new ValidationException("最大字数不可为空");
        }
        if (this.maxTokens < 0) {
            throw new ValidationException("最大字数不可小于0");
        }
        // Optional parameters: fall back to defaults instead of rejecting the request.
        if (this.stream == null) {
            this.stream = false;
        }
        if (this.presencePenalty == null) {
            this.presencePenalty = 0;
        }
        if (this.frequencyPenalty == null) {
            this.frequencyPenalty = 0;
        }
    }

    /**
     * Validates this request and converts it into a parameter map.
     *
     * <p>{@link #verify()} runs first, so defaulted values are what end up in
     * the map. The {@code stop} entry may be {@code null} because it is not
     * validated. The {@code messages} entry is the internal list itself, not a copy.
     *
     * @return a new map keyed by {@code model}, {@code messages}, {@code temperature},
     *         {@code n}, {@code max_tokens}, {@code stop}, {@code stream},
     *         {@code presence_penalty} and {@code frequency_penalty}
     * @throws ValidationException if validation fails
     */
    public Map<String, Object> toMap() {
        this.verify();
        Map<String, Object> map = new HashMap<>();
        map.put("model", this.model);
        map.put("messages", this.messages);
        map.put("temperature", this.temperature);
        map.put("n", this.n);
        map.put("max_tokens", this.maxTokens);
        map.put("stop", this.stop);
        map.put("stream", this.stream);
        map.put("presence_penalty", this.presencePenalty);
        map.put("frequency_penalty", this.frequencyPenalty);
        return map;
    }

    /**
     * Private no-argument constructor; callers must use the parameterized constructor.
     */
    private FastChatRequest() {
    }

    /**
     * Creates a request with the given generation parameters and no messages.
     *
     * <p>No validation happens here; values are checked when {@link #verify()}
     * or {@link #toMap()} is called. The {@code stop} list is stored as passed
     * (no defensive copy).
     *
     * @param model            name of the LLM model to use
     * @param temperature      temperature value; {@link #verify()} requires [0, 1]
     * @param n                number of history conversation rounds
     * @param maxTokens        maximum length of the output
     * @param stop             stop list; may be {@code null}
     * @param stream           whether streaming output is enabled; may be {@code null}
     * @param presencePenalty  presence penalty; may be {@code null}
     * @param frequencyPenalty frequency penalty; may be {@code null}
     */
    public FastChatRequest(String model, BigDecimal temperature, Integer n, Integer maxTokens, List<String> stop, Boolean stream, Integer presencePenalty, Integer frequencyPenalty) {
        this.model = model;
        this.temperature = temperature;
        this.n = n;
        this.maxTokens = maxTokens;
        this.stop = stop;
        this.stream = stream;
        this.presencePenalty = presencePenalty;
        this.frequencyPenalty = frequencyPenalty;
    }

    /**
     * Name of the LLM model to use.
     */
    private String model;
    /**
     * Conversation messages; {@code null} until the first message is added.
     */
    private List<Message> messages;
    /**
     * Temperature; {@link #verify()} requires a value within [0, 1].
     */
    private BigDecimal temperature;
    /**
     * Number of history conversation rounds.
     * NOTE(review): OpenAI-compatible APIs commonly use {@code n} for the number
     * of completions - confirm the intended meaning against the FastChat API.
     */
    private Integer n;
    /**
     * Maximum length of the output (sent as {@code max_tokens}).
     */
    private Integer maxTokens;
    /**
     * Stop list; presumably stop sequences - confirm against the FastChat API.
     * Not validated and may be {@code null}.
     */
    private List<String> stop;
    /**
     * Whether streaming output is enabled; defaulted to {@code false} by {@link #verify()}.
     */
    private Boolean stream;
    /**
     * Presence penalty; defaulted to {@code 0} by {@link #verify()}.
     */
    private Integer presencePenalty;
    /**
     * Frequency penalty; defaulted to {@code 0} by {@link #verify()}.
     */
    private Integer frequencyPenalty;

    /**
     * @return the model name
     */
    public String getModel() {
        return model;
    }

    /**
     * @return the internal message list (not a copy), or {@code null} if no message was added
     */
    public List<Message> getMessages() {
        return messages;
    }

    /**
     * @return the temperature
     */
    public BigDecimal getTemperature() {
        return temperature;
    }

    /**
     * @return the number of history conversation rounds
     */
    public Integer getN() {
        return n;
    }

    /**
     * @return the maximum output length
     */
    public Integer getMaxTokens() {
        return maxTokens;
    }

    /**
     * @return the stop list as passed to the constructor; may be {@code null}
     */
    public List<String> getStop() {
        return stop;
    }

    /**
     * @return whether streaming output is enabled; may be {@code null} before {@link #verify()}
     */
    public Boolean getStream() {
        return stream;
    }

    /**
     * @return the presence penalty; may be {@code null} before {@link #verify()}
     */
    public Integer getPresencePenalty() {
        return presencePenalty;
    }

    /**
     * @return the frequency penalty; may be {@code null} before {@link #verify()}
     */
    public Integer getFrequencyPenalty() {
        return frequencyPenalty;
    }

    /**
     * A single conversation message (role plus content).
     *
     * <p>Declared {@code static} because it uses no state of the enclosing
     * request; a non-static inner class would keep a hidden reference to the
     * enclosing {@code FastChatRequest} in every message.
     */
    private static class Message {

        private Message() {
        }

        public Message(String role, String content) {
            this.role = role;
            this.content = content;
        }

        /**
         * Role of the message author.
         */
        private String role;
        /**
         * Message text.
         */
        private String content;

        /**
         * @return the role of the message author
         */
        public String getRole() {
            return role;
        }

        /**
         * @return the message text
         */
        public String getContent() {
            return content;
        }
    }

}
