package com.infoepoch.pms.dispatchassistant.infractructure.langchain;

import com.infoepoch.pms.dispatchassistant.common.utils.OracleUtils;
import com.infoepoch.pms.dispatchassistant.domain.langchain.record.conversation.line.ConversationLine;
import com.infoepoch.pms.dispatchassistant.domain.langchain.record.conversation.line.ConversationLineCriteria;
import com.infoepoch.pms.dispatchassistant.domain.langchain.record.conversation.line.IConversationLineRepository;
import io.micrometer.core.instrument.util.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.support.rowset.SqlRowSet;
import org.springframework.stereotype.Repository;

import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * {@link JdbcTemplate}-backed implementation of {@link IConversationLineRepository} that persists
 * {@link ConversationLine} rows in table {@code AI_CONVERSATION_LINE}.
 */
@Repository
public class ConversationLineRepository implements IConversationLineRepository {

    /**
     * INSERT statement shared by {@link #insert} and {@link #batchInsert}. Bind order: CL_ID, then the
     * 22 data columns in the order bound by {@link #bindDataColumns}.
     */
    private static final String INSERT_SQL = "INSERT INTO AI_CONVERSATION_LINE(CL_ID, CL_CONVERSATION_ID, CL_TYPE, CL_QUERY, CL_STREAM, CL_MODEL_NAME, CL_TEMPERATURE, CL_MAX_TOKENS, CL_PROMPT_NAME, CL_HISTORY_LEN, CL_N, CL_STOP, CL_PRESENCE_PENALTY, CL_FREQUENCY_PENALTY, CL_KNOWLEDGE_BASE_NAME, CL_TOP_K, CL_SCORE_THRESHOLD, CL_SEARCH_ENGINE_NAME, CL_CREATE_TIME, CL_RESULT, CL_ERROR_FLAG, CL_ERROR_DETAIL, CL_ANSWER) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";

    /**
     * UPDATE statement shared by {@link #update} and {@link #batchUpdate}. Bind order: the 22 data
     * columns in the order bound by {@link #bindDataColumns}, then CL_ID for the WHERE clause.
     */
    private static final String UPDATE_SQL = "UPDATE AI_CONVERSATION_LINE SET CL_CONVERSATION_ID = ?, CL_TYPE = ?, CL_QUERY = ?, CL_STREAM = ?, CL_MODEL_NAME = ?, CL_TEMPERATURE = ?, CL_MAX_TOKENS = ?, CL_PROMPT_NAME = ?, CL_HISTORY_LEN = ?, CL_N = ?, CL_STOP = ?, CL_PRESENCE_PENALTY = ?, CL_FREQUENCY_PENALTY = ?, CL_KNOWLEDGE_BASE_NAME = ?, CL_TOP_K = ?, CL_SCORE_THRESHOLD = ?, CL_SEARCH_ENGINE_NAME = ?, CL_CREATE_TIME = ?, CL_RESULT = ?, CL_ERROR_FLAG = ?, CL_ERROR_DETAIL = ?, CL_ANSWER = ? WHERE CL_ID = ?";

    @Autowired
    private JdbcTemplate jdbcTemplate;

    /**
     * Inserts a single conversation line.
     *
     * @param entity the line to insert; all 23 columns, including CL_ID, are written
     */
    @Override
    public void insert(ConversationLine entity) {
        jdbcTemplate.update(INSERT_SQL, entity.getId(), entity.getConversationId(), entity.getType(), entity.getQuery(), entity.getStream(),
                entity.getModelName(), entity.getTemperature(), entity.getMaxTokens(), entity.getPromptName(), entity.getHistoryLen(),
                entity.getN(), entity.getStop(), entity.getPresencePenalty(), entity.getFrequencyPenalty(), entity.getKnowledgeBaseName(),
                entity.getTopK(), entity.getScoreThreshold(), entity.getSearchEngineName(), entity.getCreateTime(), entity.getResult(),
                entity.getErrorFlag(), entity.getErrorDetail(), entity.getAnswer());
    }

    /**
     * Inserts all given lines in a single JDBC batch.
     *
     * @param list the lines to insert; a {@code null} or empty list is a no-op
     */
    @Override
    public void batchInsert(List<ConversationLine> list) {
        if (list == null || list.isEmpty()) {
            return;
        }
        jdbcTemplate.batchUpdate(INSERT_SQL, new BatchPreparedStatementSetter() {
            @Override
            public void setValues(PreparedStatement ps, int i) throws SQLException {
                ConversationLine line = list.get(i);
                ps.setString(1, line.getId());
                bindDataColumns(ps, 2, line);
            }

            @Override
            public int getBatchSize() {
                return list.size();
            }
        });
    }

    /**
     * Overwrites every data column of the row whose CL_ID equals {@code entity.getId()}.
     *
     * @param entity the line carrying the new column values and the key
     */
    @Override
    public void update(ConversationLine entity) {
        jdbcTemplate.update(UPDATE_SQL, entity.getConversationId(), entity.getType(), entity.getQuery(), entity.getStream(),
                entity.getModelName(), entity.getTemperature(), entity.getMaxTokens(), entity.getPromptName(), entity.getHistoryLen(),
                entity.getN(), entity.getStop(), entity.getPresencePenalty(), entity.getFrequencyPenalty(), entity.getKnowledgeBaseName(),
                entity.getTopK(), entity.getScoreThreshold(), entity.getSearchEngineName(), entity.getCreateTime(), entity.getResult(),
                entity.getErrorFlag(), entity.getErrorDetail(), entity.getAnswer(), entity.getId());
    }

    /**
     * Updates all given lines, matched by CL_ID, in a single JDBC batch.
     *
     * @param list the lines to update; a {@code null} or empty list is a no-op
     */
    @Override
    public void batchUpdate(List<ConversationLine> list) {
        if (list == null || list.isEmpty()) {
            return;
        }
        jdbcTemplate.batchUpdate(UPDATE_SQL, new BatchPreparedStatementSetter() {
            @Override
            public void setValues(PreparedStatement ps, int i) throws SQLException {
                ConversationLine line = list.get(i);
                int next = bindDataColumns(ps, 1, line);
                // WHERE CL_ID = ? is the last placeholder.
                ps.setString(next, line.getId());
            }

            @Override
            public int getBatchSize() {
                return list.size();
            }
        });
    }

    /**
     * Binds the 22 non-key columns shared by {@link #INSERT_SQL} (after CL_ID) and
     * {@link #UPDATE_SQL} (before CL_ID), in that shared column order.
     *
     * @param ps    the statement to bind
     * @param start the 1-based index of the first data-column placeholder
     * @param line  the source of the values
     * @return the index of the next unbound placeholder
     * @throws SQLException if the driver rejects a binding
     */
    private static int bindDataColumns(PreparedStatement ps, int start, ConversationLine line) throws SQLException {
        int j = start;
        ps.setString(j++, line.getConversationId());
        ps.setString(j++, line.getType());
        ps.setString(j++, line.getQuery());
        ps.setObject(j++, line.getStream());
        ps.setString(j++, line.getModelName());
        ps.setBigDecimal(j++, line.getTemperature());
        ps.setObject(j++, line.getMaxTokens());
        ps.setString(j++, line.getPromptName());
        ps.setObject(j++, line.getHistoryLen());
        ps.setObject(j++, line.getN());
        ps.setString(j++, line.getStop());
        ps.setObject(j++, line.getPresencePenalty());
        ps.setObject(j++, line.getFrequencyPenalty());
        ps.setString(j++, line.getKnowledgeBaseName());
        ps.setObject(j++, line.getTopK());
        ps.setObject(j++, line.getScoreThreshold());
        ps.setString(j++, line.getSearchEngineName());
        ps.setTimestamp(j++, line.getCreateTime() == null ? null : new Timestamp(line.getCreateTime().getTime()));
        ps.setString(j++, line.getResult());
        // setObject, not setBoolean: the flag may be null (convertRowSet maps SQL NULL to null), and
        // setBoolean would throw NullPointerException while unboxing it.
        ps.setObject(j++, line.getErrorFlag());
        ps.setString(j++, line.getErrorDetail());
        ps.setString(j++, line.getAnswer());
        return j;
    }

    /**
     * Looks up a line by primary key.
     *
     * @param id the CL_ID to look up
     * @return the first matching line, or {@code null} when none exists
     */
    @Override
    public ConversationLine selectById(String id) {
        SqlRowSet sqlRowSet = jdbcTemplate.queryForRowSet("SELECT * FROM AI_CONVERSATION_LINE WHERE CL_ID = ?", id);
        if (sqlRowSet.next()) {
            return convertRowSet(sqlRowSet);
        }
        return null;
    }

    /**
     * Returns every line matching the criteria (see {@link #createAndMap} for the supported filters).
     *
     * @param criteria the filter; must not be {@code null}
     * @return the matching lines, never {@code null}
     */
    @Override
    public List<ConversationLine> selectByCriteria(ConversationLineCriteria criteria) {
        StringBuffer sql = new StringBuffer("SELECT * FROM AI_CONVERSATION_LINE");
        List<Object> params = OracleUtils.combinationSql(sql, createAndMap(criteria));
        return mapAll(jdbcTemplate.queryForRowSet(sql.toString(), params.toArray()));
    }

    /**
     * Returns one page of lines matching the criteria. Paging SQL is produced by
     * {@code OracleUtils.combinationSql}; the meaning of {@code pageIndex} (0- or 1-based) is defined
     * there.
     *
     * @param criteria  the filter; must not be {@code null}
     * @param pageIndex the page to fetch
     * @param pageSize  the number of rows per page
     * @return the lines on the requested page, never {@code null}
     */
    @Override
    public List<ConversationLine> selectByCriteriaPage(ConversationLineCriteria criteria, int pageIndex, int pageSize) {
        StringBuffer sql = new StringBuffer("SELECT * FROM AI_CONVERSATION_LINE");
        List<Object> params = OracleUtils.combinationSql(sql, createAndMap(criteria), pageIndex, pageSize);
        return mapAll(jdbcTemplate.queryForRowSet(sql.toString(), params.toArray()));
    }

    /**
     * Counts the lines matching the criteria.
     *
     * @param criteria the filter; must not be {@code null}
     * @return the number of matching rows
     */
    @Override
    public int selectByCriteriaCount(ConversationLineCriteria criteria) {
        StringBuffer sql = new StringBuffer("SELECT COUNT(1) FROM AI_CONVERSATION_LINE");
        List<Object> params = OracleUtils.combinationSql(sql, createAndMap(criteria));
        // Non-deprecated (requiredType, args...) overload; COUNT always yields a row, but guard the unboxing anyway.
        Integer count = jdbcTemplate.queryForObject(sql.toString(), Integer.class, params.toArray());
        return count == null ? 0 : count;
    }

    /** Drains the remaining rows of {@code rowSet} into a new list via {@link #convertRowSet}. */
    private List<ConversationLine> mapAll(SqlRowSet rowSet) {
        List<ConversationLine> list = new ArrayList<>();
        while (rowSet.next()) {
            list.add(convertRowSet(rowSet));
        }
        return list;
    }

    /**
     * Maps the current row of {@code rowSet} to a {@link ConversationLine}. Nullable numeric/boolean
     * columns are checked with {@code getObject} first so SQL NULL becomes {@code null}, not 0/false.
     */
    private ConversationLine convertRowSet(SqlRowSet rowSet) {
        return new ConversationLine(
                rowSet.getString("CL_ID"),
                rowSet.getString("CL_CONVERSATION_ID"),
                rowSet.getString("CL_TYPE"),
                rowSet.getString("CL_QUERY"),
                rowSet.getObject("CL_STREAM") == null ? null : rowSet.getBoolean("CL_STREAM"),
                rowSet.getString("CL_MODEL_NAME"),
                rowSet.getBigDecimal("CL_TEMPERATURE"),
                rowSet.getObject("CL_MAX_TOKENS") == null ? null : rowSet.getInt("CL_MAX_TOKENS"),
                rowSet.getString("CL_PROMPT_NAME"),
                rowSet.getObject("CL_HISTORY_LEN") == null ? null : rowSet.getInt("CL_HISTORY_LEN"),
                rowSet.getObject("CL_N") == null ? null : rowSet.getInt("CL_N"),
                rowSet.getString("CL_STOP"),
                // NOTE(review): penalties and score threshold are read as int, so any fractional value
                // stored in these columns is truncated — confirm the column and entity types.
                rowSet.getObject("CL_PRESENCE_PENALTY") == null ? null : rowSet.getInt("CL_PRESENCE_PENALTY"),
                rowSet.getObject("CL_FREQUENCY_PENALTY") == null ? null : rowSet.getInt("CL_FREQUENCY_PENALTY"),
                rowSet.getString("CL_KNOWLEDGE_BASE_NAME"),
                rowSet.getObject("CL_TOP_K") == null ? null : rowSet.getInt("CL_TOP_K"),
                rowSet.getObject("CL_SCORE_THRESHOLD") == null ? null : rowSet.getInt("CL_SCORE_THRESHOLD"),
                rowSet.getString("CL_SEARCH_ENGINE_NAME"),
                rowSet.getTimestamp("CL_CREATE_TIME"),
                rowSet.getString("CL_RESULT"),
                rowSet.getObject("CL_ERROR_FLAG") == null ? null : rowSet.getBoolean("CL_ERROR_FLAG"),
                rowSet.getString("CL_ERROR_DETAIL"),
                rowSet.getString("CL_ANSWER")
        );
    }

    /**
     * Builds the AND-condition map consumed by {@code OracleUtils.combinationSql}: each key is an SQL
     * predicate fragment with one placeholder, and each value is the value bound to it.
     *
     * @param criteria the filter; must not be {@code null}
     * @return the conditions to AND together; empty when no filter is set
     */
    private Map<String, Object> createAndMap(ConversationLineCriteria criteria) {
        Map<String, Object> andMap = new HashMap<>();

        // Conversation ID
        if (StringUtils.isNotBlank(criteria.getConversationId())) {
            andMap.put(" CL_CONVERSATION_ID = ? ", criteria.getConversationId());
        }

        return andMap;
    }

}
