/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.sparse.algorithm.seismic;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import lombok.Generated;
import lombok.NonNull;
import org.apache.commons.collections4.CollectionUtils;
import org.opensearch.common.Randomness;
import org.opensearch.neuralsearch.sparse.accessor.SparseVectorReader;
import org.opensearch.neuralsearch.sparse.algorithm.ClusteringAlgorithm;
import org.opensearch.neuralsearch.sparse.algorithm.PostingsProcessingUtils;
import org.opensearch.neuralsearch.sparse.data.DocWeight;
import org.opensearch.neuralsearch.sparse.data.DocumentCluster;
import org.opensearch.neuralsearch.sparse.data.SparseVector;

/**
 * {@link ClusteringAlgorithm} that picks randomly chosen documents from the input as cluster
 * centers and assigns each document to the center whose vector yields the highest dot product
 * with the document's vector.
 */
public class RandomClusteringAlgorithm implements ClusteringAlgorithm {
    /** Passed through unchanged to {@code PostingsProcessingUtils.summarize} for every cluster built. */
    private final float summaryPruneRatio;
    /**
     * Fraction of the input size used to derive the number of clusters; exactly {@code 0}
     * short-circuits to a single cluster holding all documents.
     */
    private final float clusterRatio;
    /** Source of the sparse vectors, looked up by document id. */
    @NonNull
    private final SparseVectorReader reader;

    /**
     * Groups the given documents into clusters around randomly selected center documents.
     *
     * <p>Documents for which the reader returns {@code null} are not placed in any cluster.
     * Clusters that end up without any document are omitted from the result.
     *
     * @param docWeights documents to cluster; {@code null} or empty yields an empty list
     * @return the non-empty clusters; a single cluster wrapping {@code docWeights} when
     *         {@code clusterRatio} is {@code 0}
     * @throws IOException declared by this method; presumably raised by the vector reader or
     *         by summarization (not visible in this file)
     */
    @Override
    public List<DocumentCluster> cluster(List<DocWeight> docWeights) throws IOException {
        if (CollectionUtils.isEmpty(docWeights)) {
            return Collections.emptyList();
        }
        if (this.clusterRatio == 0.0f) {
            // No clustering requested: wrap everything in one cluster. No summary is supplied
            // (first argument is null) and summarize() is not called on this path.
            // NOTE(review): the meaning of the trailing boolean flag is defined by
            // DocumentCluster, which is outside this file -- confirm there.
            DocumentCluster cluster = new DocumentCluster(null, docWeights, true);
            return List.of(cluster);
        }
        int size = docWeights.size();
        // At least one cluster, at most one per document.
        int numCluster = Math.min(size, Math.max(1, (int) Math.ceil((float) size * this.clusterRatio)));
        // Distinct random indices into docWeights; each selected document acts as a center.
        int[] centers = Randomness.get().ints(0, size).distinct().limit(numCluster).toArray();

        List<List<DocWeight>> docAssignments = new ArrayList<>(numCluster);
        List<SparseVector> centerVectors = new ArrayList<>(numCluster);
        for (int i = 0; i < numCluster; ++i) {
            docAssignments.add(new ArrayList<>());
            // May be null when the reader has no vector for the chosen document.
            centerVectors.add(this.reader.read(docWeights.get(centers[i]).getDocID()));
        }

        for (DocWeight docWeight : docWeights) {
            SparseVector docVector = this.reader.read(docWeight.getDocID());
            if (docVector == null) {
                // Without a vector the document cannot be scored; it is left out of every cluster.
                continue;
            }
            byte[] denseDocVector = docVector.toDenseVector();
            // Falls back to cluster 0 only when no center has a usable vector.
            int centerIdx = 0;
            // Float.MIN_VALUE is the smallest POSITIVE float, so it must not be used as a
            // "lowest possible score" sentinel; negative infinity lets any real score win.
            float maxScore = Float.NEGATIVE_INFINITY;
            for (int i = 0; i < numCluster; ++i) {
                SparseVector center = centerVectors.get(i);
                if (center == null) {
                    // A center without a vector can never be the best match.
                    continue;
                }
                float score = center.dotProduct(denseDocVector);
                if (score > maxScore) {
                    maxScore = score;
                    centerIdx = i;
                }
            }
            docAssignments.get(centerIdx).add(docWeight);
        }

        List<DocumentCluster> clusters = new ArrayList<>();
        for (List<DocWeight> assigned : docAssignments) {
            if (assigned.isEmpty()) {
                continue;
            }
            DocumentCluster cluster = new DocumentCluster(null, assigned, false);
            // Presumably fills in the cluster summary that was passed as null above -- see
            // PostingsProcessingUtils.summarize for the exact contract.
            PostingsProcessingUtils.summarize(cluster, this.reader, this.summaryPruneRatio);
            clusters.add(cluster);
        }
        return clusters;
    }

    /**
     * Creates the algorithm.
     *
     * @param summaryPruneRatio ratio forwarded to cluster summarization
     * @param clusterRatio fraction of the document count used as the number of clusters;
     *        {@code 0} disables clustering
     * @param reader sparse vector source; must not be {@code null}
     * @throws NullPointerException if {@code reader} is {@code null}
     */
    @Generated
    public RandomClusteringAlgorithm(float summaryPruneRatio, float clusterRatio, @NonNull SparseVectorReader reader) {
        Objects.requireNonNull(reader, "reader is marked non-null but is null");
        this.summaryPruneRatio = summaryPruneRatio;
        this.clusterRatio = clusterRatio;
        this.reader = reader;
    }
}

