/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.normalization;

import com.google.common.primitives.Floats;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.commons.lang3.Validate;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationUtils;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationUtil;
import org.opensearch.neuralsearch.processor.normalization.bounds.BoundMode;
import org.opensearch.neuralsearch.processor.normalization.bounds.LowerBound;
import org.opensearch.neuralsearch.processor.normalization.bounds.UpperBound;
import org.opensearch.neuralsearch.processor.util.ProcessorUtils;

/**
 * Min-max score normalization technique ({@value #TECHNIQUE_NAME}).
 *
 * <p>For every sub-query index, the minimum and maximum scores are collected over all non-null
 * {@link CompoundTopDocs} entries of the request. Each score is then rescaled as
 * {@code (score - min) / (max - min)}. A rescaled value of exactly {@code 0} is replaced by
 * {@link #MIN_SCORE}.
 *
 * <p>Optional {@code lower_bounds} / {@code upper_bounds} parameters hold one entry per
 * sub-query. Each entry is a map with {@code mode} and {@code min_score} / {@code max_score}.
 * {@link LowerBound} and {@link UpperBound} use them to derive the effective min/max and decide
 * whether a score is clipped. The exact semantics of each mode live in those classes.
 */
public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique {
    public static final String TECHNIQUE_NAME = "min_max";
    /** Score used for a normalized value of exactly zero and for scores clipped to the lower bound. */
    protected static final float MIN_SCORE = 0.001f;
    /** Score used for scores clipped to the upper bound, or when effective min equals effective max. */
    protected static final float MAX_SCORE = 1.0f;
    /** Score assigned when a sub-query's min, max and the score itself are all equal. */
    private static final float SINGLE_RESULT_SCORE = 1.0f;
    private static final String PARAM_NAME_LOWER_BOUNDS = "lower_bounds";
    private static final String PARAM_NAME_BOUND_MODE = "mode";
    private static final String PARAM_NAME_LOWER_BOUND_MIN_SCORE = "min_score";
    private static final String PARAM_NAME_UPPER_BOUNDS = "upper_bounds";
    private static final String PARAM_NAME_UPPER_BOUND_MAX_SCORE = "max_score";
    /** Value shown by {@link #describe()} when a lower bound omits {@code min_score}. */
    private static final float DEFAULT_LOWER_BOUND_SCORE = 0.0f;
    /** Value shown by {@link #describe()} when an upper bound omits {@code max_score}. */
    private static final float DEFAULT_UPPER_BOUND_SCORE = 1.0f;
    /** Maximum number of entries accepted in a {@code lower_bounds} / {@code upper_bounds} list. */
    private static final int MAX_NUMBER_OF_BOUNDS = 5;
    /** Inclusive range accepted for an explicit {@code min_score} / {@code max_score}. */
    private static final float MIN_BOUND_SCORE = -10000.0f;
    private static final float MAX_BOUND_SCORE = 10000.0f;
    private static final Set<String> SUPPORTED_PARAMETERS = Set.of(PARAM_NAME_LOWER_BOUNDS, PARAM_NAME_UPPER_BOUNDS);
    private static final Map<String, Set<String>> NESTED_PARAMETERS = Map.of(
        PARAM_NAME_LOWER_BOUNDS,
        Set.of(PARAM_NAME_BOUND_MODE, PARAM_NAME_LOWER_BOUND_MIN_SCORE),
        PARAM_NAME_UPPER_BOUNDS,
        Set.of(PARAM_NAME_BOUND_MODE, PARAM_NAME_UPPER_BOUND_MAX_SCORE)
    );

    private final Optional<List<Map<String, Object>>> lowerBoundsParamsOptional;
    private final Optional<List<Map<String, Object>>> upperBoundsParamsOptional;

    /** Creates the technique with no bounds configured. */
    public MinMaxScoreNormalizationTechnique() {
        this(Map.of(), new ScoreNormalizationUtil());
    }

    /**
     * Creates the technique from user-supplied parameters.
     *
     * @param params technique parameters; may contain {@code lower_bounds} and/or {@code upper_bounds}
     * @param scoreNormalizationUtil helper used to validate parameter names
     * @throws IllegalArgumentException if a bounds parameter is malformed or too long
     */
    public MinMaxScoreNormalizationTechnique(final Map<String, Object> params, final ScoreNormalizationUtil scoreNormalizationUtil) {
        scoreNormalizationUtil.validateParameters(params, SUPPORTED_PARAMETERS, NESTED_PARAMETERS);
        this.lowerBoundsParamsOptional = getBoundsParams(params, PARAM_NAME_LOWER_BOUNDS);
        this.upperBoundsParamsOptional = getBoundsParams(params, PARAM_NAME_UPPER_BOUNDS);
    }

    /**
     * Normalizes every sub-query score in place, rewriting {@code ScoreDoc.score}.
     *
     * @param normalizeScoresDTO carrier of the per-shard compound top docs to normalize
     * @throws IllegalArgumentException if a configured bounds list does not match the sub-query count
     */
    @Override
    public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
        List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
        MinMaxScores minMaxScores = getMinMaxScoresResult(queryTopDocs);
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            validateBoundsMatchSubQueries(topDocsPerSubQuery);
            for (int subQueryIndex = 0; subQueryIndex < topDocsPerSubQuery.size(); subQueryIndex++) {
                float minScore = minMaxScores.getMinScoresPerSubquery()[subQueryIndex];
                float maxScore = minMaxScores.getMaxScoresPerSubquery()[subQueryIndex];
                LowerBound lowerBound = getLowerBound(subQueryIndex);
                UpperBound upperBound = getUpperBound(subQueryIndex);
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(subQueryIndex).scoreDocs) {
                    scoreDoc.score = normalizeSingleScore(scoreDoc.score, minScore, maxScore, lowerBound, upperBound);
                }
            }
        }
    }

    /**
     * Throws if a configured bounds list has a different length than the (non-empty) sub-query list.
     */
    private void validateBoundsMatchSubQueries(final List<TopDocs> topDocsPerSubQuery) {
        if (isBoundsAndSubQueriesCountMismatched(topDocsPerSubQuery)) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "expected bounds array to contain %d elements matching the number of sub-queries, but found a mismatch",
                    topDocsPerSubQuery.size()
                )
            );
        }
    }

    /** Returns true when either bounds list is present and its size differs from the sub-query count. */
    private boolean isBoundsAndSubQueriesCountMismatched(final List<TopDocs> topDocsPerSubQuery) {
        if (topDocsPerSubQuery.isEmpty()) {
            return false;
        }
        int numberOfSubQueries = topDocsPerSubQuery.size();
        boolean lowerBoundsMismatch = lowerBoundsParamsOptional.map(bounds -> bounds.size() != numberOfSubQueries).orElse(false);
        boolean upperBoundsMismatch = upperBoundsParamsOptional.map(bounds -> bounds.size() != numberOfSubQueries).orElse(false);
        return lowerBoundsMismatch || upperBoundsMismatch;
    }

    /** Lower bound for the given sub-query, or a default-constructed one when none is configured. */
    private LowerBound getLowerBound(final int subQueryIndex) {
        return lowerBoundsParamsOptional.map(bounds -> bounds.get(subQueryIndex)).map(LowerBound::new).orElseGet(LowerBound::new);
    }

    /** Upper bound for the given sub-query, or a default-constructed one when none is configured. */
    private UpperBound getUpperBound(final int subQueryIndex) {
        return upperBoundsParamsOptional.map(bounds -> bounds.get(subQueryIndex)).map(UpperBound::new).orElseGet(UpperBound::new);
    }

    /** Collects per-sub-query min and max scores over all non-null entries of {@code queryTopDocs}. */
    private MinMaxScores getMinMaxScoresResult(final List<CompoundTopDocs> queryTopDocs) {
        int numOfSubqueries = ProcessorUtils.getNumOfSubqueries(queryTopDocs);
        return new MinMaxScores(getMinScores(queryTopDocs, numOfSubqueries), getMaxScores(queryTopDocs, numOfSubqueries));
    }

    @Override
    public String techniqueName() {
        return TECHNIQUE_NAME;
    }

    /**
     * Human-readable description, e.g. {@code min_max, lower bounds [(mode, score), ...]}.
     * Each bounds section is only appended when that parameter was configured.
     */
    @Override
    public String describe() {
        return TECHNIQUE_NAME
            + buildBoundDescription(lowerBoundsParamsOptional, "lower", PARAM_NAME_LOWER_BOUND_MIN_SCORE, DEFAULT_LOWER_BOUND_SCORE)
            + buildBoundDescription(upperBoundsParamsOptional, "upper", PARAM_NAME_UPPER_BOUND_MAX_SCORE, DEFAULT_UPPER_BOUND_SCORE);
    }

    /** Formats one bounds list as {@code ", <label> bounds [(mode, score), ...]"}, or "" if absent. */
    private String buildBoundDescription(
        final Optional<List<Map<String, Object>>> boundsOptional,
        final String label,
        final String scoreParamName,
        final float defaultScore
    ) {
        return boundsOptional.map(
            bounds -> bounds.stream()
                .map(boundMap -> formatBound(boundMap, scoreParamName, defaultScore))
                .collect(Collectors.joining(", ", "[", "]"))
        ).map(formatted -> String.format(Locale.ROOT, ", %s bounds %s", label, formatted)).orElse("");
    }

    /** Formats a single bound as {@code (mode, score)}, falling back to {@code defaultScore} for a missing score. */
    private String formatBound(final Map<String, Object> boundMap, final String scoreParamName, final float defaultScore) {
        BoundMode mode = BoundMode.fromString(Objects.toString(boundMap.get(PARAM_NAME_BOUND_MODE), ""));
        String score = Objects.toString(boundMap.get(scoreParamName), String.valueOf(defaultScore));
        return String.format(Locale.ROOT, "(%s, %s)", mode, score);
    }

    /**
     * Computes the same normalized scores as {@link #normalize}, writes them back into each
     * {@code ScoreDoc}, and returns explanation details per document.
     *
     * @param queryTopDocs per-shard compound top docs; null entries are skipped
     * @return explanation details keyed by document id and search shard
     * @throws IllegalArgumentException if a configured bounds list does not match the sub-query count
     */
    @Override
    public Map<DocIdAtSearchShard, ExplanationDetails> explain(final List<CompoundTopDocs> queryTopDocs) {
        MinMaxScores minMaxScores = getMinMaxScoresResult(queryTopDocs);
        Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            // Same guard as normalize(); without it a short bounds list fails with IndexOutOfBoundsException.
            validateBoundsMatchSubQueries(topDocsPerSubQuery);
            int numberOfSubQueries = topDocsPerSubQuery.size();
            for (int subQueryIndex = 0; subQueryIndex < numberOfSubQueries; subQueryIndex++) {
                float minScore = minMaxScores.getMinScoresPerSubquery()[subQueryIndex];
                float maxScore = minMaxScores.getMaxScoresPerSubquery()[subQueryIndex];
                LowerBound lowerBound = getLowerBound(subQueryIndex);
                UpperBound upperBound = getUpperBound(subQueryIndex);
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(subQueryIndex).scoreDocs) {
                    DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
                    float normalizedScore = normalizeSingleScore(scoreDoc.score, minScore, maxScore, lowerBound, upperBound);
                    ScoreNormalizationUtil.setNormalizedScore(normalizedScores, docIdAtSearchShard, subQueryIndex, numberOfSubQueries, normalizedScore);
                    scoreDoc.score = normalizedScore;
                }
            }
        }
        return ExplanationUtils.getDocIdAtQueryForNormalization(normalizedScores, this);
    }

    /**
     * Per-sub-query maximum score. Sub-queries with no hits keep {@link Float#NEGATIVE_INFINITY},
     * but that value is never read because there is no doc to normalize.
     */
    private float[] getMaxScores(final List<CompoundTopDocs> queryTopDocs, final int numOfSubqueries) {
        float[] maxScores = new float[numOfSubqueries];
        // Float.MIN_VALUE is the smallest *positive* float, not the most negative one. Used as a
        // sentinel, it made an all-zero sub-query report max > min. NEGATIVE_INFINITY is the max identity.
        Arrays.fill(maxScores, Float.NEGATIVE_INFINITY);
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int subQueryIndex = 0; subQueryIndex < topDocsPerSubQuery.size(); subQueryIndex++) {
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(subQueryIndex).scoreDocs) {
                    maxScores[subQueryIndex] = Math.max(maxScores[subQueryIndex], scoreDoc.score);
                }
            }
        }
        return maxScores;
    }

    /**
     * Per-sub-query minimum score. Sub-queries with no hits keep {@link Float#POSITIVE_INFINITY},
     * but that value is never read because there is no doc to normalize.
     */
    private float[] getMinScores(final List<CompoundTopDocs> queryTopDocs, final int numOfSubqueries) {
        float[] minScores = new float[numOfSubqueries];
        Arrays.fill(minScores, Float.POSITIVE_INFINITY);
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int subQueryIndex = 0; subQueryIndex < topDocsPerSubQuery.size(); subQueryIndex++) {
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(subQueryIndex).scoreDocs) {
                    minScores[subQueryIndex] = Math.min(minScores[subQueryIndex], scoreDoc.score);
                }
            }
        }
        return minScores;
    }

    /**
     * Normalizes one score. The steps are applied in this order:
     * <ol>
     *   <li>If the score equals both min and max, return {@link #SINGLE_RESULT_SCORE}.</li>
     *   <li>Derive the effective min/max from the bounds.</li>
     *   <li>Clip to {@link #MIN_SCORE} or {@link #MAX_SCORE} if the bounds request it.</li>
     *   <li>Otherwise rescale linearly.</li>
     * </ol>
     */
    private float normalizeSingleScore(
        final float score,
        final float minScore,
        final float maxScore,
        final LowerBound lowerBound,
        final UpperBound upperBound
    ) {
        if (isSingleScore(score, minScore, maxScore)) {
            return SINGLE_RESULT_SCORE;
        }
        float effectiveMinScore = lowerBound.determineEffectiveScore(score, minScore, maxScore);
        float effectiveMaxScore = upperBound.determineEffectiveScore(score, minScore, maxScore);
        if (lowerBound.shouldClipToBound(score, effectiveMinScore)) {
            return MIN_SCORE;
        }
        if (upperBound.shouldClipToBound(score, effectiveMaxScore)) {
            return MAX_SCORE;
        }
        return calculateNormalizedScore(score, effectiveMinScore, effectiveMaxScore);
    }

    /** True when min, max and score are all equal (e.g. a sub-query with one distinct score). */
    private boolean isSingleScore(final float score, final float minScore, final float maxScore) {
        return Floats.compare(maxScore, minScore) == 0 && Floats.compare(maxScore, score) == 0;
    }

    /**
     * Linear min-max rescaling.
     *
     * @return {@link #MAX_SCORE} if the effective range is empty; {@link #MIN_SCORE} if the result is
     *     exactly {@code 0}; otherwise {@code (score - min) / (max - min)}
     */
    @VisibleForTesting
    protected float calculateNormalizedScore(final float score, final float effectiveMinScore, final float effectiveMaxScore) {
        if (Floats.compare(effectiveMaxScore, effectiveMinScore) == 0) {
            return MAX_SCORE;
        }
        float normalizedScore = (score - effectiveMinScore) / (effectiveMaxScore - effectiveMinScore);
        return normalizedScore == 0.0f ? MIN_SCORE : normalizedScore;
    }

    /**
     * Reads and validates one bounds parameter ({@code lower_bounds} or {@code upper_bounds}).
     *
     * @param params technique parameters, may be null
     * @param paramName {@code lower_bounds} or {@code upper_bounds}
     * @return the validated list of bound maps, or empty if the parameter is absent
     * @throws IllegalArgumentException if the value is not a list, has more than
     *     {@value #MAX_NUMBER_OF_BOUNDS} entries, contains a non-map entry, an invalid mode, or a
     *     score outside [-10000, 10000]; also if {@code paramName} is not a supported bounds name
     */
    @VisibleForTesting
    protected Optional<List<Map<String, Object>>> getBoundsParams(final Map<String, Object> params, final String paramName) {
        if (Objects.isNull(params) || !params.containsKey(paramName)) {
            return Optional.empty();
        }
        List<?> boundsParams = Optional.ofNullable(params.get(paramName))
            .filter(List.class::isInstance)
            .map(List.class::cast)
            .orElseThrow(() -> new IllegalArgumentException(paramName + " must be a List"));
        if (boundsParams.size() > MAX_NUMBER_OF_BOUNDS) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "%s size %d should be less than or equal to %d", paramName, boundsParams.size(), MAX_NUMBER_OF_BOUNDS)
            );
        }
        // Each bounds kind stores its score under a different key.
        final String scoreParamName = switch (paramName) {
            case PARAM_NAME_LOWER_BOUNDS -> PARAM_NAME_LOWER_BOUND_MIN_SCORE;
            case PARAM_NAME_UPPER_BOUNDS -> PARAM_NAME_UPPER_BOUND_MAX_SCORE;
            default -> throw new IllegalArgumentException(String.format(Locale.ROOT, "Unsupported bounds parameter name: %s", paramName));
        };
        return Optional.of(boundsParams.stream().map(item -> parseBound(item, scoreParamName)).collect(Collectors.toList()));
    }

    /**
     * Validates a single bounds entry and returns it as a map.
     *
     * @throws IllegalArgumentException if {@code item} is not a map, or its mode/score is invalid
     */
    @SuppressWarnings("unchecked")
    private Map<String, Object> parseBound(final Object item, final String scoreParamName) {
        if (!(item instanceof Map)) {
            throw new IllegalArgumentException("each bound must be a map");
        }
        // Keys come from the request's parameter map; assumed to be Strings.
        Map<String, Object> boundMap = (Map<String, Object>) item;
        String modeString = Objects.toString(boundMap.get(PARAM_NAME_BOUND_MODE), "");
        if (!modeString.isEmpty()) {
            // Result discarded: called only for validation (presumably throws on an unknown mode — TODO confirm).
            BoundMode.fromString(modeString);
        }
        validateBoundScore(boundMap, scoreParamName);
        return boundMap;
    }

    /**
     * Checks that an explicit score, if present, parses as a float within [-10000, 10000].
     * NaN and infinities fail the range check.
     *
     * @throws IllegalArgumentException if the score cannot be parsed or is out of range
     */
    private void validateBoundScore(final Map<String, Object> bound, final String scoreParamName) {
        Object scoreObj = bound.get(scoreParamName);
        if (scoreObj == null) {
            return;
        }
        final float score;
        try {
            score = Float.parseFloat(String.valueOf(scoreObj));
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "invalid format for %s: must be a valid float value", scoreParamName),
                e
            );
        }
        Validate.isTrue(
            score >= MIN_BOUND_SCORE && score <= MAX_BOUND_SCORE,
            "%s must be a valid finite number between %f and %f",
            scoreParamName,
            MIN_BOUND_SCORE,
            MAX_BOUND_SCORE
        );
    }

    @Generated
    @Override
    public String toString() {
        return "MinMaxScoreNormalizationTechnique(TECHNIQUE_NAME=min_max)";
    }

    /** Per-sub-query min and max scores, indexed by sub-query position. */
    private static final class MinMaxScores {
        private final float[] minScoresPerSubquery;
        private final float[] maxScoresPerSubquery;

        MinMaxScores(final float[] minScoresPerSubquery, final float[] maxScoresPerSubquery) {
            this.minScoresPerSubquery = minScoresPerSubquery;
            this.maxScoresPerSubquery = maxScoresPerSubquery;
        }

        float[] getMinScoresPerSubquery() {
            return minScoresPerSubquery;
        }

        float[] getMaxScoresPerSubquery() {
            return maxScoresPerSubquery;
        }
    }
}

