package com.oxygenxml.smartautocomplete.core.openai;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.oxygenxml.smartautocomplete.core.CannotComputeCompletionDetailsException;
import com.oxygenxml.smartautocomplete.core.CompletionDetailsProvider;
import com.oxygenxml.smartautocomplete.core.Suggestion;
import com.oxygenxml.smartautocomplete.plugin.openai.OpenAIConstants;
import com.oxygenxml.smartautocomplete.plugin.openai.OpenAIFacade;
import java.io.IOException;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;
import org.apache.commons.lang3.StringUtils;
import org.apache.fontbox.ttf.NamingTable;
import org.apache.pdfbox.pdmodel.documentinterchange.taggedpdf.PDLayoutAttributeObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:oxygen-smart-autocomplete-addon-1.0.1-SNAPSHOT/lib/oxygen-smart-autocomplete-addon-1.0.1-SNAPSHOT.jar:com/oxygenxml/smartautocomplete/core/openai/OpenAICompletionDetailsProvider.class */
/**
 * {@link CompletionDetailsProvider} backed by the OpenAI completions HTTP API.
 *
 * <p>For a given prefix it sends one completion request to the selected (or caller-specified)
 * engine, converts every returned choice into a {@link Suggestion}, and then sends one
 * content-filter request per suggestion, dropping the suggestions labelled as unsafe.
 */
public class OpenAICompletionDetailsProvider extends OpenAIAPI implements CompletionDetailsProvider {
    /** Completion length used when the caller does not pass {@code max_tokens}. */
    private static final int DEFAULT_MAX_TOKENS_FOR_COMPLETION = 64;
    private static final String LOGPROBS = "logprobs";
    private static final String ENGINE = "engine";
    private static final String STOP = "stop";
    private static final String N = "n";
    private static final String MAX_TOKENS = "max_tokens";
    private static final String PROMPT = "prompt";
    private static final String USER = "user";
    private static final String TEMPERATURE = "temperature";
    private static final String TOP_P = "top_p";
    private static final String MODEL = "model";
    private static final String CHOICES = "choices";
    private static final String TEXT = "text";
    private static final String TOP_LOGPROBS = "top_logprobs";
    /** Key of the engine id inside the {@code engine} request parameter. */
    private static final String ENGINE_NAME = "name";
    /** Key of the fine-tune flag inside the {@code engine} request parameter. */
    private static final String ENGINE_FINE_TUNE = "fine-tune";

    /** Content-filter label treated as safe. */
    private static final String LABEL_SAFE = "0";
    /** Content-filter label for sensitive text; {@link #isSafeFromResponse} still accepts it. */
    private static final String LABEL_SENSITIVE = "1";
    /** Content-filter label for unsafe text; suggestions with this label are dropped. */
    private static final String LABEL_UNSAFE = "2";
    /**
     * If the log-probability of the "unsafe" label is below this value, the "unsafe" verdict is
     * considered unreliable and replaced by the more likely of the other two labels. The value is
     * the threshold recommended by OpenAI's content-filter guide (written as a double so it is not
     * silently widened from an imprecise float literal).
     */
    private static final double TOXIC_THRESHOLD = -0.355;

    static final Logger logger = LoggerFactory.getLogger(OpenAICompletionDetailsProvider.class.getName());

    /** Source of the user id and of the default engine. */
    private final OpenAIFacade openAIFacade;

    /**
     * Creates a provider that talks to {@link OpenAIAPI#DEFAULT_BASE_URL} and uses the facade
     * itself as the API-key provider.
     *
     * @param openAIFacade supplies the user id, the selected engine and the API key.
     */
    public OpenAICompletionDetailsProvider(OpenAIFacade openAIFacade) {
        this(openAIFacade, OpenAIAPI.DEFAULT_BASE_URL, openAIFacade);
    }

    /**
     * Creates a provider for a custom base URL.
     *
     * @param openAIFacade supplies the user id and the selected engine.
     * @param str base URL of the OpenAI API (no trailing path).
     * @param openAIKeyProvider supplies the API key.
     */
    public OpenAICompletionDetailsProvider(OpenAIFacade openAIFacade, String str, OpenAIKeyProvider openAIKeyProvider) {
        super(str, openAIKeyProvider);
        this.openAIFacade = openAIFacade;
    }

    @Override // com.oxygenxml.smartautocomplete.core.CompletionDetailsProvider
    public List<Suggestion> getSuggestions(String str) throws CannotComputeCompletionDetailsException {
        return getSuggestions(str, null);
    }

    /**
     * Requests completions for a prefix and returns those that pass the content filter.
     *
     * @param str the prompt prefix.
     * @param map optional extra request parameters (may be {@code null}). {@code max_tokens} is
     *     capped, {@code engine} selects the engine, and every other entry is forwarded verbatim.
     * @return the safe suggestions, possibly empty.
     * @throws CannotComputeCompletionDetailsException if the prompt is too large or any HTTP/JSON
     *     step fails.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public List<Suggestion> getSuggestions(String str, Map map) throws CannotComputeCompletionDetailsException {
        try {
            HttpURLConnection request = prepareRequest(str, map);
            logger.debug("Request is ready, executing.");
            List<Suggestion> suggestions = (List<Suggestion>) execute(request, this::getCompletionDetailsFromResponse);
            return filterUnsafeSuggestions(suggestions);
        } catch (IOException e) {
            throw new CannotComputeCompletionDetailsException(e.getMessage(), e);
        }
    }

    /**
     * Keeps only the suggestions the content filter considers safe. Issues one HTTP request per
     * suggestion.
     */
    private List<Suggestion> filterUnsafeSuggestions(List<Suggestion> list) throws IOException {
        if (list == null) {
            return new ArrayList<>();
        }
        List<Suggestion> safeSuggestions = new ArrayList<>(list.size());
        for (Suggestion suggestion : list) {
            if (isSafeSuggestion(suggestion)) {
                safeSuggestions.add(suggestion);
            }
        }
        return safeSuggestions;
    }

    /** Runs the content filter on a single suggestion. */
    private boolean isSafeSuggestion(Suggestion suggestion) throws IOException {
        logger.debug("Checking if suggestion is safe.");
        return ((Boolean) execute(prepareContentFilterRequest(suggestion), this::isSafeFromResponse)).booleanValue();
    }

    /** Builds the POST request to the content-filter engine, with its JSON payload already written. */
    HttpURLConnection prepareContentFilterRequest(Suggestion suggestion) throws IOException {
        HttpURLConnection request = createSimplePostRequest(this.baseUrl + "/v1/engines/content-filter-alpha/completions", true);
        writePayload(request, getContentFilterRequestPayload(suggestion));
        return request;
    }

    /**
     * Builds the JSON payload for the content filter: the suggestion tokens joined with spaces and
     * wrapped in the prompt format the filter expects, asking for a single deterministic token with
     * its top log-probabilities.
     */
    String getContentFilterRequestPayload(Suggestion suggestion) throws JsonProcessingException {
        Map<String, Object> parameters = new LinkedHashMap<>();
        parameters.put(PROMPT, "<|endoftext|>" + String.join(StringUtils.SPACE, suggestion.getTokenList()) + "\n--\nLabel:");
        parameters.put(MAX_TOKENS, 1);
        parameters.put(TEMPERATURE, 0);
        parameters.put(TOP_P, 0);
        parameters.put(LOGPROBS, 10);
        parameters.put(USER, this.openAIFacade.getUserID());
        String payload = this.objectMapper.writeValueAsString(parameters);
        logger.debug("The parameters for content filter {} ", parameters);
        return payload;
    }

    /**
     * Writes a UTF-8 JSON body to the connection. The output stream is closed so the underlying
     * resources are released even if writing fails.
     */
    private void writePayload(HttpURLConnection httpURLConnection, String str) throws IOException {
        httpURLConnection.setRequestProperty("Content-Type", "application/json");
        try (OutputStream outputStream = httpURLConnection.getOutputStream()) {
            outputStream.write(str.getBytes(StandardCharsets.UTF_8));
            outputStream.flush();
        }
    }

    /**
     * Converts a completion response into suggestions, one per choice. Choices without a textual
     * {@code text} field are skipped instead of failing the whole request.
     */
    private List<Suggestion> getCompletionDetailsFromResponse(String str) throws JsonProcessingException {
        logger.debug("Completion response from OpenAI {}", str);
        @SuppressWarnings("unchecked")
        Map<String, Object> response = this.objectMapper.readValue(str, Map.class);
        debugMap(response);
        List<Suggestion> suggestions = new ArrayList<>();
        Object choices = response.get(CHOICES);
        if (choices instanceof List) {
            for (Object choice : (List<?>) choices) {
                Object text = choice instanceof Map ? ((Map<?, ?>) choice).get(TEXT) : null;
                if (text instanceof String) {
                    Suggestion suggestion = new Suggestion(1.0d, getTokens((String) text));
                    suggestion.setCanContainMarkup(false);
                    suggestions.add(suggestion);
                } else {
                    logger.debug("Skipping completion choice without text: {}", choice);
                }
            }
        }
        return suggestions;
    }

    /** @return {@code true} unless the content filter labelled the text as unsafe. */
    boolean isSafeFromResponse(String str) throws JsonProcessingException {
        boolean safe = !LABEL_UNSAFE.equals(getSafetyLabel(str));
        logger.debug("Is safe {}", Boolean.valueOf(safe));
        return safe;
    }

    /**
     * Extracts the content-filter label from a response.
     *
     * <p>An "unsafe" label whose log-probability is below {@link #TOXIC_THRESHOLD} is downgraded to
     * whichever of "safe"/"sensitive" is more likely. Anything that is not one of the three known
     * labels, including a malformed response, is reported as unsafe.
     *
     * @return {@code "0"}, {@code "1"} or {@code "2"}.
     */
    String getSafetyLabel(String str) throws JsonProcessingException {
        logger.debug("Content filter response from OpenAI {}", str);
        @SuppressWarnings("unchecked")
        Map<String, Object> response = this.objectMapper.readValue(str, Map.class);
        debugMap(response);
        Map<?, ?> firstChoice = getFirstChoice(response);
        Object text = firstChoice != null ? firstChoice.get(TEXT) : null;
        String label = text instanceof String ? (String) text : null;
        logger.debug("Output label from content filter {}", label);
        if (LABEL_UNSAFE.equals(label)) {
            label = reconsiderUnsafeLabel(firstChoice);
        }
        logger.debug("The output label {} ", label);
        if (isOutsideSafetyRange(label)) {
            label = LABEL_UNSAFE;
        }
        return label;
    }

    /** @return the first element of {@code choices} if it is a JSON object, otherwise {@code null}. */
    private static Map<?, ?> getFirstChoice(Map<String, Object> response) {
        Object choices = response == null ? null : response.get(CHOICES);
        if (choices instanceof List && !((List<?>) choices).isEmpty()) {
            Object first = ((List<?>) choices).get(0);
            if (first instanceof Map) {
                return (Map<?, ?>) first;
            }
        }
        return null;
    }

    /**
     * Re-evaluates an "unsafe" label using the top log-probabilities of the generated token. Keeps
     * "unsafe" when the probabilities are missing or the unsafe log-probability is not below the
     * threshold.
     */
    private static String reconsiderUnsafeLabel(Map<?, ?> choice) {
        Map<?, ?> topLogprobs = getTopLogprobs(choice);
        Double unsafeLogprob = topLogprobs == null ? null : toDouble(topLogprobs.get(LABEL_UNSAFE));
        if (unsafeLogprob == null || unsafeLogprob.doubleValue() >= TOXIC_THRESHOLD) {
            return LABEL_UNSAFE;
        }
        logger.debug("Is not quite toxic");
        Double safeLogprob = toDouble(topLogprobs.get(LABEL_SAFE));
        Double sensitiveLogprob = toDouble(topLogprobs.get(LABEL_SENSITIVE));
        if (safeLogprob != null && sensitiveLogprob != null) {
            return safeLogprob.doubleValue() >= sensitiveLogprob.doubleValue() ? LABEL_SAFE : LABEL_SENSITIVE;
        }
        if (safeLogprob != null) {
            return LABEL_SAFE;
        }
        if (sensitiveLogprob != null) {
            return LABEL_SENSITIVE;
        }
        return LABEL_UNSAFE;
    }

    /** @return {@code choice.logprobs.top_logprobs[0]}, or {@code null} if any level is missing. */
    private static Map<?, ?> getTopLogprobs(Map<?, ?> choice) {
        Object logprobs = choice == null ? null : choice.get(LOGPROBS);
        if (!(logprobs instanceof Map)) {
            return null;
        }
        Object top = ((Map<?, ?>) logprobs).get(TOP_LOGPROBS);
        if (!(top instanceof List) || ((List<?>) top).isEmpty()) {
            return null;
        }
        Object first = ((List<?>) top).get(0);
        return first instanceof Map ? (Map<?, ?>) first : null;
    }

    /** @return the numeric value as a {@code Double}, or {@code null} if it is not a number. */
    private static Double toDouble(Object value) {
        return value instanceof Number ? Double.valueOf(((Number) value).doubleValue()) : null;
    }

    /** @return {@code true} if the label is not one of "0", "1" or "2" (including {@code null}). */
    private boolean isOutsideSafetyRange(String str) {
        return !(LABEL_UNSAFE.equals(str) || LABEL_SENSITIVE.equals(str) || LABEL_SAFE.equals(str));
    }

    /** Splits text on the default {@link StringTokenizer} whitespace delimiters. */
    private List<String> getTokens(String str) {
        List<String> tokens = new ArrayList<>();
        StringTokenizer tokenizer = new StringTokenizer(str);
        while (tokenizer.hasMoreTokens()) {
            tokens.add(tokenizer.nextToken());
        }
        return tokens;
    }

    /**
     * Builds the completion POST request. Fine-tuned models use the generic {@code /v1/completions}
     * endpoint, and the model is then passed in the payload. Base engines use the per-engine
     * endpoint.
     */
    HttpURLConnection prepareRequest(String str, Map<String, Object> map) throws IOException, CannotComputeCompletionDetailsException {
        HttpURLConnection request;
        Engine engine = getEngine(map);
        logger.debug("Using engine {}", engine);
        if (engine.isFineTune()) {
            request = createSimplePostRequest(this.baseUrl + "/v1/completions", true);
            logger.debug("Is a fine tune, using {}", request.getURL());
        } else {
            request = createSimplePostRequest(this.baseUrl + "/v1/engines/" + engine.getId() + "/completions", true);
            logger.debug("Is not a fine tune, using {}", request.getURL());
        }
        writePayload(request, getRequestPayload(str, map, engine));
        return request;
    }

    /**
     * Resolves the engine to use. The {@code engine} parameter, when present, must be a map with a
     * non-null {@code name} and an optional boolean {@code fine-tune}. Otherwise the facade's
     * selected engine is used.
     */
    Engine getEngine(Map<String, Object> map) {
        if (map != null) {
            Object engineParam = map.get(ENGINE);
            if (engineParam instanceof Map) {
                Map<?, ?> engineDetails = (Map<?, ?>) engineParam;
                Object name = engineDetails.get(ENGINE_NAME);
                if (name != null) {
                    return new Engine(name.toString(), Boolean.TRUE.equals(engineDetails.get(ENGINE_FINE_TUNE)));
                }
                logger.debug("The engine name from the parameters is null. Will use the default one.");
            } else if (engineParam != null) {
                logger.debug("Ignoring engine parameter of unexpected type {}. Will use the default one.", engineParam.getClass().getName());
            }
        }
        return this.openAIFacade.getSelectedEngine();
    }

    /**
     * Builds the JSON payload for a completion request after validating the prompt size.
     *
     * <p>Defaults are: one choice, stop at ".", and the (capped) max tokens. Caller parameters,
     * except {@code max_tokens} and {@code engine}, are applied last and override the defaults.
     * NOTE(review): this also lets callers override {@code prompt}/{@code model}/{@code user},
     * bypassing the size check; confirm that is intended.
     */
    String getRequestPayload(String str, Map<String, Object> map, Engine engine) throws JsonProcessingException, CannotComputeCompletionDetailsException {
        int cappedMaxTokens = getCappedMaxTokens(map);
        checkLimit(str, cappedMaxTokens);
        Map<String, Object> parameters = new LinkedHashMap<>();
        parameters.put(PROMPT, str);
        if (engine.isFineTune()) {
            parameters.put(MODEL, engine.getId());
        }
        parameters.put(MAX_TOKENS, Integer.valueOf(cappedMaxTokens));
        parameters.put(N, 1);
        parameters.put(STOP, Arrays.asList("."));
        parameters.put(USER, this.openAIFacade.getUserID());
        if (map != null) {
            Map<String, Object> overrides = new LinkedHashMap<>(map);
            overrides.remove(MAX_TOKENS);
            overrides.remove(ENGINE);
            parameters.putAll(overrides);
        }
        debugMap(parameters);
        String payload = this.objectMapper.writeValueAsString(parameters);
        logger.debug("The parameters {} ", parameters);
        return payload;
    }

    /**
     * Validates that the prompt and the requested completion fit the API limits.
     *
     * @throws CannotComputeCompletionDetailsException if the prompt has too many characters or the
     *     estimated prompt tokens plus {@code i} exceed the combined limit.
     */
    static void checkLimit(String str, int i) throws CannotComputeCompletionDetailsException {
        checkIfExceedsMaxCharsLimit(str);
        checkTotalNumberOfTokens(str, i);
    }

    /** Rejects prompts longer than {@link OpenAIConstants#MAX_NUMBER_OF_INPUT_CHARS_IN_PROMPT} characters. */
    private static void checkIfExceedsMaxCharsLimit(String str) throws CannotComputeCompletionDetailsException {
        if (str.length() > OpenAIConstants.MAX_NUMBER_OF_INPUT_CHARS_IN_PROMPT) {
            throw new CannotComputeCompletionDetailsException("The number of characters in the prefix is too large for a prompt: " + str.length() + " , maximum is: " + OpenAIConstants.MAX_NUMBER_OF_INPUT_CHARS_IN_PROMPT);
        }
    }

    /** Rejects requests whose estimated prompt tokens plus completion tokens exceed the combined limit. */
    static void checkTotalNumberOfTokens(String str, int i) throws CannotComputeCompletionDetailsException {
        int approximateNumberOfTokens = getAproximateNumberOfTokens(str);
        if (approximateNumberOfTokens + i > OpenAIConstants.NUMBER_OF_TOKENS_IN_BOTH_PROMPT_AND_COMPLETION) {
            throw new CannotComputeCompletionDetailsException("The number of tokens in the selection is too large. (" + approximateNumberOfTokens + " for prompt, " + i + " for completion, " + OpenAIConstants.NUMBER_OF_TOKENS_IN_BOTH_PROMPT_AND_COMPLETION + " maximum allowed)");
        }
    }

    /**
     * Reads {@code max_tokens} from the parameters (any numeric type), defaulting to
     * {@link #DEFAULT_MAX_TOKENS_FOR_COMPLETION}, and caps it at
     * {@link OpenAIConstants#MAX_NUMBER_OF_TOKENS_IN_COMPLETION}.
     */
    private static int getCappedMaxTokens(Map<String, Object> map) {
        int maxTokens = DEFAULT_MAX_TOKENS_FOR_COMPLETION;
        if (map != null) {
            Object requested = map.get(MAX_TOKENS);
            if (requested instanceof Number) {
                maxTokens = ((Number) requested).intValue();
            } else if (requested != null) {
                logger.debug("Ignoring non-numeric {} value {}", MAX_TOKENS, requested);
            }
        }
        if (maxTokens > OpenAIConstants.MAX_NUMBER_OF_TOKENS_IN_COMPLETION) {
            logger.debug("Capped the number of completion tokens from {}, to {}", Integer.valueOf(maxTokens), Integer.valueOf(OpenAIConstants.MAX_NUMBER_OF_TOKENS_IN_COMPLETION));
            maxTokens = OpenAIConstants.MAX_NUMBER_OF_TOKENS_IN_COMPLETION;
        }
        return maxTokens;
    }

    /** Rough token estimate (about three characters per token). */
    static int getAproximateNumberOfTokens(String str) {
        return str.length() / 3;
    }
}
