/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.fedplanner;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.cost.ComputeCost;
import org.apache.sysds.hops.fedplanner.FederatedMemoTable;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;

/**
 * Cost model for federated execution plans stored in a {@link FederatedMemoTable}.
 *
 * <p>A plan's own cost combines the hop's compute cost ({@link ComputeCost#getHOPComputeCost})
 * with memory-access costs; moving a hop's output over the network is charged separately as its
 * net transfer cost. Access costs are memory estimates scaled by fixed default bandwidths.
 */
public class FederatedPlanCostEstimator {
    /** Default memory bandwidth (MB/s, going by the constant name) used for memory-access costs. */
    private static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 25000.0;
    /** Default network bandwidth (MB/s, going by the constant name) used for transfer costs. */
    private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0;
    /** Converts memory estimates to MB; assumes the hop estimates are in bytes — TODO confirm. */
    private static final double BYTES_PER_MB = 1024.0 * 1024.0;

    /**
     * Computes and stores the total cost of {@code currentPlan}.
     *
     * <p>The plan's self cost and net transfer cost are computed on first use; a self cost of
     * {@code 0.0} is treated as "not yet computed". The total cost is the self cost plus, for each
     * child reference, the total cost of the plan returned by
     * {@code memoTable.getMinCostFedPlan(child)} and that child's conditional net transfer cost with
     * respect to this plan's output type.
     *
     * @param currentPlan plan to cost; its self, net transfer and total cost are updated in place
     * @param memoTable memo table used to resolve each child reference to a plan
     */
    public static void computeFederatedPlanCost(FederatedMemoTable.FedPlan currentPlan, FederatedMemoTable memoTable) {
        Hop currentHop = currentPlan.getHopRef();
        double totalCost = currentPlan.getSelfCost();
        if (totalCost == 0.0) {
            // 0.0 doubles as the "not yet computed" sentinel; a genuine zero cost is simply recomputed.
            totalCost = computeCurrentCost(currentHop);
            currentPlan.setSelfCost(totalCost);
            currentPlan.setNetTransferCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate()));
        }
        for (Pair<Long, FEDInstruction.FederatedOutput> childPlanPair : currentPlan.getChildFedPlans()) {
            FederatedMemoTable.FedPlan childFedPlan = memoTable.getMinCostFedPlan(childPlanPair);
            // The child's transfer cost is conditional on this plan's output type.
            totalCost += childFedPlan.getTotalCost() + childFedPlan.getCondNetTransferCost(currentPlan.getFedOutType());
        }
        currentPlan.setTotalCost(totalCost);
    }

    /**
     * Resolves each conflicted hop (a hop referenced by several parent plans, presumably with
     * differing output types) to a single federated output type, LOUT or FOUT.
     *
     * <p>For every conflicted hop, the additional cost of making all its parents reference the LOUT
     * variant, and of making them reference the FOUT variant, is estimated relative to the costs
     * already folded into the parents' totals. Ties favour LOUT. The chosen additional cost is added
     * to {@code cumulativeAdditionalCost[0]}, the chosen plan variant is recorded in the returned
     * map, and every parent's child reference to the hop is rewritten to the chosen type. Costs
     * stored on the plans themselves are not recomputed.
     *
     * @param memoTable memo table providing the pruned LOUT and FOUT plans of each conflicted hop
     * @param conflictFedPlanLinkedMap conflicted hop ID mapped to the parent plans referencing it
     * @param cumulativeAdditionalCost single-element accumulator; element 0 is incremented by the
     *     additional cost of every resolution
     * @return the chosen plan variants, each mapped to {@code true}, in insertion order
     * @throws NoSuchElementException if a listed parent plan does not reference its conflicted hop
     */
    public static LinkedHashMap<FederatedMemoTable.FedPlan, Boolean> resolveConflictFedPlan(FederatedMemoTable memoTable, LinkedHashMap<Long, List<FederatedMemoTable.FedPlan>> conflictFedPlanLinkedMap, double[] cumulativeAdditionalCost) {
        LinkedHashMap<FederatedMemoTable.FedPlan, Boolean> resolvedFedPlanLinkedMap = new LinkedHashMap<>();
        for (Map.Entry<Long, List<FederatedMemoTable.FedPlan>> conflictEntry : conflictFedPlanLinkedMap.entrySet()) {
            long conflictHopID = conflictEntry.getKey();
            List<FederatedMemoTable.FedPlan> conflictParentFedPlans = conflictEntry.getValue();
            FederatedMemoTable.FedPlan conflictLOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FEDInstruction.FederatedOutput.LOUT);
            FederatedMemoTable.FedPlan conflictFOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FEDInstruction.FederatedOutput.FOUT);

            double lOutAdditionalCost = 0.0;
            double fOutAdditionalCost = 0.0;
            // The conflicted hop's output is transferred at most once per candidate type, however
            // many parents need it: per-parent transfer costs are removed below and a single transfer
            // is added back afterwards if any parent still requires one.
            boolean isLOutNetTransfer = false;
            boolean isFOutNetTransfer = false;

            // Per-parent bookkeeping, where "calculated" is the type the parent currently references.
            // Assumes a transfer cost is included in a parent's total exactly when the parent's and
            // child's output types differ (presumably how FedPlan#getCondNetTransferCost behaves).
            for (FederatedMemoTable.FedPlan conflictParentFedPlan : conflictParentFedPlans) {
                Pair<Long, FEDInstruction.FederatedOutput> calculatedPlanPair = findChildPlanPair(conflictParentFedPlan, conflictHopID);
                FEDInstruction.FederatedOutput parentFedOutType = conflictParentFedPlan.getFedOutType();
                if (calculatedPlanPair.getRight() == FEDInstruction.FederatedOutput.LOUT) {
                    // Switching to FOUT swaps the child's LOUT total for its FOUT total.
                    fOutAdditionalCost += conflictFOutFedPlan.getTotalCost() - conflictLOutFedPlan.getTotalCost();
                    if (parentFedOutType == FEDInstruction.FederatedOutput.LOUT) {
                        // No transfer so far; choosing FOUT introduces one (charged once, later).
                        isFOutNetTransfer = true;
                    } else {
                        // The LOUT variant's transfer is in the parent's total: FOUT removes it,
                        // LOUT keeps it but it is re-added only once, later.
                        isLOutNetTransfer = true;
                        lOutAdditionalCost -= conflictLOutFedPlan.getNetTransferCost();
                        fOutAdditionalCost -= conflictLOutFedPlan.getNetTransferCost();
                    }
                } else {
                    // Switching to LOUT swaps the child's FOUT total for its LOUT total.
                    lOutAdditionalCost += conflictLOutFedPlan.getTotalCost() - conflictFOutFedPlan.getTotalCost();
                    if (parentFedOutType == FEDInstruction.FederatedOutput.FOUT) {
                        // No transfer so far; choosing LOUT introduces one (charged once, later).
                        isLOutNetTransfer = true;
                    } else {
                        // The FOUT variant's transfer is in the parent's total: LOUT removes it,
                        // FOUT keeps it but it is re-added only once, later.
                        isFOutNetTransfer = true;
                        lOutAdditionalCost -= conflictFOutFedPlan.getNetTransferCost();
                        fOutAdditionalCost -= conflictFOutFedPlan.getNetTransferCost();
                    }
                }
            }
            if (isLOutNetTransfer) {
                lOutAdditionalCost += conflictLOutFedPlan.getNetTransferCost();
            }
            if (isFOutNetTransfer) {
                fOutAdditionalCost += conflictFOutFedPlan.getNetTransferCost();
            }

            FEDInstruction.FederatedOutput optimalFedOutType;
            if (lOutAdditionalCost <= fOutAdditionalCost) {
                optimalFedOutType = FEDInstruction.FederatedOutput.LOUT;
                cumulativeAdditionalCost[0] += lOutAdditionalCost;
                resolvedFedPlanLinkedMap.put(conflictLOutFedPlan, true);
            } else {
                optimalFedOutType = FEDInstruction.FederatedOutput.FOUT;
                cumulativeAdditionalCost[0] += fOutAdditionalCost;
                resolvedFedPlanLinkedMap.put(conflictFOutFedPlan, true);
            }
            updateChildFedOutType(conflictParentFedPlans, conflictHopID, optimalFedOutType);
        }
        return resolvedFedPlanLinkedMap;
    }

    /**
     * Returns the first child reference of {@code parentFedPlan} that points to {@code childHopID}.
     *
     * @throws NoSuchElementException if the parent does not reference that hop
     */
    private static Pair<Long, FEDInstruction.FederatedOutput> findChildPlanPair(FederatedMemoTable.FedPlan parentFedPlan, long childHopID) {
        return parentFedPlan.getChildFedPlans().stream()
            .filter(pair -> pair.getLeft().equals(childHopID))
            .findFirst()
            .orElseThrow(() -> new NoSuchElementException("No matching pair found for ID: " + childHopID));
    }

    /**
     * Rewrites every reference to {@code childHopID} in the parents' child lists to the
     * {@code fedOutType} variant. Only the referenced type changes; no costs are recomputed.
     */
    private static void updateChildFedOutType(List<FederatedMemoTable.FedPlan> parentFedPlans, long childHopID, FEDInstruction.FederatedOutput fedOutType) {
        for (FederatedMemoTable.FedPlan parentFedPlan : parentFedPlans) {
            List<Pair<Long, FEDInstruction.FederatedOutput>> childFedPlans = parentFedPlan.getChildFedPlans();
            for (int i = 0; i < childFedPlans.size(); i++) {
                Pair<Long, FEDInstruction.FederatedOutput> childPlanPair = childFedPlans.get(i);
                // Rewrite all occurrences: a parent such as X + X references the same hop twice, and
                // stopping at the first one would leave the others at the stale output type.
                if (childPlanPair.getLeft() == childHopID && childPlanPair.getRight() != fedOutType) {
                    childFedPlans.set(i, Pair.of(childPlanPair.getLeft(), fedOutType));
                }
            }
        }
    }

    /**
     * Estimates the hop's own execution cost: the larger of its compute cost and its input
     * memory-access cost, plus its output memory-access cost.
     */
    private static double computeCurrentCost(Hop currentHop) {
        double computeCost = ComputeCost.getHOPComputeCost(currentHop);
        double inputAccessCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate());
        double outputAccessCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate());
        // Only the larger of compute and input-read cost is charged (presumably modelling the two as
        // overlapping); writing the output is charged on top.
        return Math.max(computeCost, inputAccessCost) + outputAccessCost;
    }

    /** Cost of reading or writing {@code memSize} at the default memory bandwidth. */
    private static double computeHopMemoryAccessCost(double memSize) {
        return memSize / BYTES_PER_MB / DEFAULT_MBS_MEMORY_BANDWIDTH;
    }

    /** Cost of sending {@code memSize} over the network at the default network bandwidth. */
    private static double computeHopNetworkAccessCost(double memSize) {
        return memSize / BYTES_PER_MB / DEFAULT_MBS_NETWORK_BANDWIDTH;
    }
}

