/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tinkerpop.gremlin.process.traversal.step.filter;

import java.io.Serializable;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.function.BinaryOperator;
import org.apache.tinkerpop.gremlin.process.computer.MemoryComputeKey;
import org.apache.tinkerpop.gremlin.process.traversal.Traversal;
import org.apache.tinkerpop.gremlin.process.traversal.Traverser;
import org.apache.tinkerpop.gremlin.process.traversal.lambda.ConstantTraversal;
import org.apache.tinkerpop.gremlin.process.traversal.step.ByModulating;
import org.apache.tinkerpop.gremlin.process.traversal.step.Seedable;
import org.apache.tinkerpop.gremlin.process.traversal.step.TraversalParent;
import org.apache.tinkerpop.gremlin.process.traversal.step.util.CollectingBarrierStep;
import org.apache.tinkerpop.gremlin.process.traversal.traverser.ProjectedTraverser;
import org.apache.tinkerpop.gremlin.process.traversal.traverser.TraverserRequirement;
import org.apache.tinkerpop.gremlin.process.traversal.traverser.util.TraverserSet;
import org.apache.tinkerpop.gremlin.process.traversal.util.TraversalProduct;
import org.apache.tinkerpop.gremlin.process.traversal.util.TraversalUtil;
import org.apache.tinkerpop.gremlin.structure.util.StringFactory;

/**
 * Collecting barrier step that cuts the gathered traversers down to a random
 * sample of {@code amountToSample} (counted by bulk).
 *
 * <p>Each incoming traverser is paired with a numeric weight produced by the
 * probability traversal (a constant {@code 1.0} unless replaced through
 * {@link #modulateBy(Traversal.Admin)}). When the barrier is consumed and the
 * total bulk exceeds {@code amountToSample}, traversers are drawn one bulk unit
 * at a time with a chance proportional to that weight.
 *
 * @param <S> the type of object carried by the traversers passing through this step
 */
public final class SampleGlobalStep<S> extends CollectingBarrierStep<S>
        implements TraversalParent, ByModulating, Seedable {
    /** Set once a by-modulator has been applied (or substituted); guards against a second one. */
    private boolean detectedMutipleBy = false;
    /** Child traversal yielding the sampling weight of a traverser; defaults to a constant 1.0. */
    private Traversal.Admin<S, Number> probabilityTraversal = new ConstantTraversal<>(1.0);
    private final int amountToSample;
    private final Random random = new Random();
    private final SampleBiOperator<S> traverserReducer = new SampleBiOperator<>();

    /**
     * Creates the step.
     *
     * @param traversal      the traversal this step belongs to
     * @param amountToSample the number of traversers (by bulk) to keep
     */
    public SampleGlobalStep(final Traversal.Admin traversal, final int amountToSample) {
        super(traversal);
        this.amountToSample = amountToSample;
    }

    /**
     * @return the number of traversers (by bulk) this step keeps
     */
    public int getAmountToSample() {
        return this.amountToSample;
    }

    /**
     * Re-seeds the random number generator that drives the sampling.
     *
     * @param seed the new seed
     */
    @Override
    public void resetSeed(final long seed) {
        this.random.setSeed(seed);
    }

    /**
     * @return a single-element list holding the probability traversal
     */
    public List<Traversal.Admin<S, Number>> getLocalChildren() {
        return Collections.singletonList(this.probabilityTraversal);
    }

    /**
     * Installs the traversal used to compute each traverser's sampling weight.
     *
     * @param probabilityTraversal the weight-producing traversal
     * @throws IllegalStateException if a by-modulator has already been applied to this step
     */
    @Override
    public void modulateBy(final Traversal.Admin<?, ?> probabilityTraversal) {
        if (this.detectedMutipleBy) {
            throw new IllegalStateException("Sample step can only have one by modulator");
        }
        this.detectedMutipleBy = true;
        this.probabilityTraversal = this.integrateChild(probabilityTraversal);
    }

    /**
     * Swaps the probability traversal for {@code newTraversal} when it equals
     * {@code oldTraversal}; otherwise does nothing.
     *
     * @param oldTraversal the child expected to be currently installed
     * @param newTraversal the child to install in its place
     */
    @Override
    public void replaceLocalChild(final Traversal.Admin<?, ?> oldTraversal, final Traversal.Admin<?, ?> newTraversal) {
        if (null != this.probabilityTraversal && this.probabilityTraversal.equals(oldTraversal)) {
            this.detectedMutipleBy = true;
            this.probabilityTraversal = this.integrateChild(newTraversal);
        }
    }

    @Override
    public String toString() {
        return StringFactory.stepString(this, this.amountToSample, this.probabilityTraversal);
    }

    /**
     * Drains all pending starts, wrapping each one together with its weight and
     * adding it to the barrier's traverser set. Starts for which the probability
     * traversal is not productive are dropped.
     */
    @Override
    public void processAllStarts() {
        while (this.starts.hasNext()) {
            this.createProjectedTraverser((Traverser.Admin<S>) this.starts.next()).ifPresent(this.traverserSet::add);
        }
    }

    /**
     * Replaces the content of {@code traverserSet} with a weighted random sample
     * of {@code amountToSample} bulk-1 traversers. The set is left untouched when
     * its total bulk does not exceed {@code amountToSample}.
     *
     * @param traverserSet the collected traversers; expected to hold the projected
     *                     traversers created by {@link #processAllStarts()}
     */
    @Override
    public void barrierConsumer(final TraverserSet<S> traverserSet) {
        // Nothing to sample away: everything collected fits within the requested amount.
        if (traverserSet.bulkSize() <= (long) this.amountToSample) {
            return;
        }
        // Total weight counts every bulk unit of every traverser.
        double totalWeight = 0.0;
        for (final Traverser.Admin<S> s : traverserSet) {
            totalWeight += SampleGlobalStep.weightOf(s) * (double) s.bulk();
        }
        final TraverserSet<S> sampledSet = this.traversal.getTraverserSetSupplier().get();
        int runningAmountToSample = 0;
        // Each pass over the set picks at most one bulk unit, then starts over.
        // NOTE(review): if every weight is zero, currentWeight / runningTotalWeight is NaN,
        // the comparison below is always false and this loop would never finish —
        // confirm whether such weights can reach this step.
        while (runningAmountToSample < this.amountToSample) {
            boolean reSample = false;
            double runningTotalWeight = totalWeight;
            for (final Traverser.Admin<S> s : traverserSet) {
                final long sampleBulk = sampledSet.contains(s) ? sampledSet.get(s).bulk() : 0L;
                // All bulk units of this traverser have already been sampled.
                if (sampleBulk >= s.bulk()) {
                    continue;
                }
                final double currentWeight = SampleGlobalStep.weightOf(s);
                // Give every not-yet-sampled bulk unit its own chance of being picked.
                for (int i = 0; (long) i < s.bulk() - sampleBulk; i++) {
                    if (this.random.nextDouble() <= currentWeight / runningTotalWeight) {
                        final Traverser.Admin<S> split = s.split();
                        split.setBulk(1L);
                        sampledSet.add(split);
                        ++runningAmountToSample;
                        reSample = true;
                        break;
                    }
                    // Unit rejected: the remaining candidates share the remaining weight.
                    runningTotalWeight -= currentWeight;
                }
                // Restart the pass after a pick, or stop once enough has been sampled.
                if (reSample || runningAmountToSample >= this.amountToSample) {
                    break;
                }
            }
        }
        traverserSet.clear();
        traverserSet.addAll(sampledSet);
    }

    /**
     * @return a memory compute key built from this step's id and its {@link SampleBiOperator} reducer
     */
    @Override
    public MemoryComputeKey<TraverserSet<S>> getMemoryComputeKey() {
        return MemoryComputeKey.of(this.getId(), this.traverserReducer, false, true);
    }

    /** Reads the sampling weight stored as the first projection of a projected traverser. */
    private static double weightOf(final Traverser.Admin<?> traverser) {
        return ((Number) ((ProjectedTraverser<?, ?>) traverser).getProjections().get(0)).doubleValue();
    }

    /**
     * Evaluates the probability traversal for {@code traverser} and wraps the
     * traverser together with the resulting weight.
     *
     * @param traverser the incoming traverser
     * @return the projected traverser, or empty when the probability traversal is not productive
     * @throws IllegalStateException if the probability traversal yields something that is not a {@link Number}
     */
    private Optional<ProjectedTraverser<S, Number>> createProjectedTraverser(final Traverser.Admin<S> traverser) {
        final TraversalProduct product = TraversalUtil.produce(traverser, this.probabilityTraversal);
        if (!product.isProductive()) {
            return Optional.empty();
        }
        final Object weight = product.get();
        if (!(weight instanceof Number)) {
            throw new IllegalStateException(String.format("Traverser %s does not evaluate to a number with %s", traverser, this.probabilityTraversal));
        }
        return Optional.of(new ProjectedTraverser<S, Number>(traverser, Collections.singletonList((Number) weight)));
    }

    @Override
    public Set<TraverserRequirement> getRequirements() {
        return this.getSelfAndChildRequirements(TraverserRequirement.BULK);
    }

    /**
     * Clones this step and gives the copy its own clone of the probability traversal.
     *
     * @return the cloned step
     */
    @Override
    public SampleGlobalStep<S> clone() {
        final SampleGlobalStep<S> clone = (SampleGlobalStep<S>) super.clone();
        clone.probabilityTraversal = this.probabilityTraversal.clone();
        clone.detectedMutipleBy = this.detectedMutipleBy;
        return clone;
    }

    /**
     * Sets the parent traversal and re-integrates the probability traversal as a child.
     *
     * @param parentTraversal the traversal that now owns this step
     */
    @Override
    public void setTraversal(final Traversal.Admin<?, ?> parentTraversal) {
        super.setTraversal(parentTraversal);
        this.integrateChild(this.probabilityTraversal);
    }

    @Override
    public int hashCode() {
        return super.hashCode() ^ this.amountToSample ^ this.probabilityTraversal.hashCode();
    }

    /**
     * Reducer that merges two traverser sets, keeping only the traversers with
     * the highest loop count found across both.
     *
     * @param <S> the type of object carried by the traversers
     */
    public static final class SampleBiOperator<S>
            implements BinaryOperator<TraverserSet<S>>, Serializable {
        /**
         * @param setA the first set
         * @param setB the second set
         * @return a new set holding the traversers from both inputs whose loop count is the maximum seen
         */
        @Override
        public TraverserSet<S> apply(final TraverserSet<S> setA, final TraverserSet<S> setB) {
            final TraverserSet<S> result = new TraverserSet<>();
            // -1 is below any loop count computed below, so the first traverser always wins.
            final int maxLoops = SampleBiOperator.processTraverserSet(setA, -1, result);
            SampleBiOperator.processTraverserSet(setB, maxLoops, result);
            return result;
        }

        /**
         * Folds one set into {@code result}: a traverser with a higher loop count
         * than seen so far resets {@code result} to just that traverser, one with
         * an equal count is added, and one with a lower count is skipped.
         *
         * @param traverserSet    the traversers to fold in
         * @param currentMaxLoops the highest loop count seen before this call
         * @param result          the accumulator being built
         * @return the highest loop count seen after processing {@code traverserSet}
         */
        private static <S> int processTraverserSet(final TraverserSet<S> traverserSet, final int currentMaxLoops, final TraverserSet<S> result) {
            int max = currentMaxLoops;
            for (final Traverser.Admin<S> traverser : traverserSet) {
                // A traverser without loop names is counted as having zero loops.
                final int loops = traverser.getLoopNames().isEmpty() ? 0 : traverser.loops();
                if (loops > max) {
                    max = loops;
                    result.clear();
                    result.add(traverser);
                } else if (loops == max) {
                    result.add(traverser);
                }
            }
            return max;
        }
    }
}

