/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderDummycode;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.util.DependencyTask;
import org.apache.sysds.runtime.util.DependencyThreadPool;

/**
 * Column encoder that chains several {@link ColumnEncoder}s operating on the same column.
 *
 * <p>The list constructor enforces that all wrapped encoders share one column ID, which becomes the
 * composite's own column ID. Build and apply are delegated to the wrapped encoders in list order.
 * In apply, only the first encoder reads the input block; every later encoder reads the output
 * block written by its predecessors.
 */
public class ColumnEncoderComposite
extends ColumnEncoder {
    private static final long serialVersionUID = -8473768154646831882L;
    /** Wrapped encoders, all for the same column; re-sorted by natural order in {@link #addEncoder}. */
    private List<ColumnEncoder> _columnEncoders = null;
    /** Optional metadata; when non-null, {@link #getMetaData} returns it instead of delegating. */
    private FrameBlock _meta = null;

    /** Creates an empty composite with column ID -1 and no encoders; {@link #readExternal} can populate it. */
    public ColumnEncoderComposite() {
        super(-1);
    }

    /**
     * Creates a composite over the given encoders.
     *
     * @param columnEncoders non-empty list of encoders that must all have the same column ID
     * @param meta           optional metadata returned by {@link #getMetaData}; may be null
     * @throws DMLRuntimeException if the list is empty or the column IDs differ
     */
    public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders, FrameBlock meta) {
        super(-1);
        if (columnEncoders.isEmpty()
            || !columnEncoders.stream().allMatch(encoder -> encoder._colID == columnEncoders.get(0)._colID)) {
            throw new DMLRuntimeException("Tried to create Composite Encoder with no encoders or mismatching columIDs");
        }
        this._colID = columnEncoders.get(0)._colID;
        this._meta = meta;
        this._columnEncoders = columnEncoders;
    }

    /** Creates a composite over the given encoders without metadata. */
    public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders) {
        this(columnEncoders, null);
    }

    /** Creates a composite wrapping a single encoder and taking over its column ID. */
    public ColumnEncoderComposite(ColumnEncoder columnEncoder) {
        super(columnEncoder._colID);
        this._columnEncoders = new ArrayList<>();
        this._columnEncoders.add(columnEncoder);
    }

    /** @return the live (mutable) list of wrapped encoders */
    public List<ColumnEncoder> getEncoders() {
        return _columnEncoders;
    }

    /**
     * Returns the first wrapped encoder whose runtime class is exactly {@code type} (subclasses do
     * not match), or null if there is none.
     */
    public <T extends ColumnEncoder> T getEncoder(Class<T> type) {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            if (columnEncoder.getClass().equals(type))
                return type.cast(columnEncoder);
        }
        return null;
    }

    /** @return true if a wrapped encoder has exactly class {@code type} and column ID {@code colID} */
    public boolean isEncoder(int colID, Class<?> type) {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            if (columnEncoder.getClass().equals(type) && columnEncoder._colID == colID)
                return true;
        }
        return false;
    }

    /** Builds every wrapped encoder sequentially, in list order, on the same input. */
    @Override
    public void build(CacheBlock in) {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            columnEncoder.build(in);
        }
    }

    /**
     * Collects the apply tasks of all wrapped encoders. The first encoder reads {@code in}; later
     * encoders read {@code out}. Every task of a later encoder depends on the last task of the
     * closest preceding encoder that contributed tasks.
     *
     * @return the dependency-wired tasks; empty if no wrapped encoder produced any task
     */
    @Override
    public List<DependencyTask<?>> getApplyTasks(CacheBlock in, MatrixBlock out, int outputCol) {
        List<DependencyTask<?>> tasks = new ArrayList<>();
        List<Integer> sizes = new ArrayList<>();
        for (int i = 0; i < _columnEncoders.size(); i++) {
            // Only the first encoder reads the input; later ones work on the column already written to out.
            CacheBlock src = (i == 0) ? in : out;
            List<DependencyTask<?>> t = _columnEncoders.get(i).getApplyTasks(src, out, outputCol);
            if (t == null)
                continue;
            sizes.add(t.size());
            tasks.addAll(t);
        }
        // Previously sizes.get(0) threw IndexOutOfBoundsException when no encoder produced tasks.
        if (tasks.isEmpty())
            return tasks;

        List<List<? extends Callable<?>>> dep = new ArrayList<>(Collections.nCopies(tasks.size(), null));
        int start = 0; // offset of the current encoder's first task within 'tasks'
        for (int size : sizes) {
            if (start > 0) {
                // Chain to the last task of the previous encoder; 'tasks' is not modified afterwards,
                // so this sub-list view stays valid.
                List<? extends Callable<?>> prev = tasks.subList(start - 1, start);
                for (int k = start; k < start + size; k++)
                    dep.set(k, prev);
            }
            start += size;
        }
        return DependencyThreadPool.createDependencyTasks(tasks, dep);
    }

    /** Not supported for composites. */
    @Override
    protected ColumnEncoder.ColumnApplyTask<? extends ColumnEncoder> getSparseTask(CacheBlock in, MatrixBlock out, int outputCol, int startRow, int blk) {
        throw new NotImplementedException();
    }

    /**
     * Collects the build tasks of all wrapped encoders. Each encoder's task range is registered in
     * {@code depMap} as depending on the previous encoder's last task. If a dummycode encoder is
     * present, an extra task refreshing its domain sizes is appended. That task depends on the last
     * build task if one exists, and has no dependency otherwise.
     */
    @Override
    public List<DependencyTask<?>> getBuildTasks(CacheBlock in) {
        List<DependencyTask<?>> tasks = new ArrayList<>();
        HashMap<Integer[], Integer[]> depMap = null;
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            List<DependencyTask<?>> t = columnEncoder.getBuildTasks(in);
            if (t == null)
                continue;
            if (!tasks.isEmpty()) {
                if (depMap == null)
                    depMap = new HashMap<>();
                // Index ranges instead of sub-list views, because 'tasks' is still growing here.
                depMap.put(new Integer[] {tasks.size(), tasks.size() + t.size()},
                    new Integer[] {tasks.size() - 1, tasks.size()});
            }
            tasks.addAll(t);
        }

        int numBuildTasks = tasks.size();
        boolean updateDC = hasEncoder(ColumnEncoderDummycode.class);
        // Append the DC update task BEFORE any sub-list views of 'tasks' are created: a structural
        // change to an ArrayList invalidates its existing sub-list views.
        if (updateDC)
            tasks.add(DependencyThreadPool.createDependencyTask(new ColumnCompositeUpdateDCTask(this)));

        List<List<? extends Callable<?>>> dep = new ArrayList<>(Collections.nCopies(tasks.size(), null));
        DependencyThreadPool.createDependencyList(tasks, depMap, dep);
        // With no build tasks there is nothing to wait for (previously subList(-1, 0) threw).
        if (updateDC && numBuildTasks > 0)
            dep.set(numBuildTasks, tasks.subList(numBuildTasks - 1, numBuildTasks));
        return DependencyThreadPool.createDependencyTasks(tasks, dep);
    }

    /** Forwards to every wrapped encoder. */
    @Override
    public void prepareBuildPartial() {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            columnEncoder.prepareBuildPartial();
        }
    }

    /** Forwards the partial build input to every wrapped encoder. */
    @Override
    public void buildPartial(FrameBlock in) {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            columnEncoder.buildPartial(in);
        }
    }

    /**
     * Applies all wrapped encoders to the given row range. The first encoder reads {@code in}; later
     * ones read {@code out}. On failure, the composite is logged and the exception is rethrown.
     *
     * @return {@code out}
     */
    @Override
    public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        try {
            for (int i = 0; i < _columnEncoders.size(); i++) {
                CacheBlock src = (i == 0) ? in : out;
                _columnEncoders.get(i).apply(src, out, outputCol, rowStart, blk);
            }
        }
        catch (Exception ex) {
            LOG.error("Failed to transform-apply frame with \n" + this);
            throw ex;
        }
        return out;
    }

    /** Always throws: a composite has no code of its own. */
    @Override
    protected double getCode(CacheBlock in, int row) {
        throw new DMLRuntimeException("CompositeEncoder does not have a Code");
    }

    @Override
    protected ColumnEncoder.TransformType getTransformType() {
        return ColumnEncoder.TransformType.N_A;
    }

    /** Two composites are equal if their encoder lists and metadata are equal (null-safe). */
    @Override
    public boolean equals(Object o) {
        if (this == o)
            return true;
        if (o == null || getClass() != o.getClass())
            return false;
        ColumnEncoderComposite that = (ColumnEncoderComposite) o;
        return Objects.equals(_columnEncoders, that._columnEncoders) && Objects.equals(_meta, that._meta);
    }

    @Override
    public int hashCode() {
        return Objects.hash(_columnEncoders, _meta);
    }

    /**
     * Merges another encoder into this one. A composite contributes each of its wrapped encoders;
     * any other encoder is added directly. Dummycode domain sizes are refreshed afterwards.
     */
    @Override
    public void mergeAt(ColumnEncoder other) {
        if (other instanceof ColumnEncoderComposite) {
            ColumnEncoderComposite otherComposite = (ColumnEncoderComposite) other;
            assert otherComposite._colID == _colID;
            for (ColumnEncoder otherEnc : otherComposite.getEncoders()) {
                addEncoder(otherEnc);
            }
        }
        else {
            addEncoder(other);
        }
        updateAllDCEncoders();
    }

    /** Lets the dummycode encoder (if present) update its domain sizes from the wrapped encoders. */
    public void updateAllDCEncoders() {
        ColumnEncoderDummycode dc = getEncoder(ColumnEncoderDummycode.class);
        if (dc != null)
            dc.updateDomainSizes(_columnEncoders);
    }

    /**
     * Adds an encoder. An existing encoder of exactly the same class is merged with {@code other};
     * otherwise {@code other} is appended and the list is re-sorted by natural order.
     */
    public void addEncoder(ColumnEncoder other) {
        ColumnEncoder existing = getEncoder(other.getClass());
        assert _colID == other._colID;
        if (existing != null) {
            existing.mergeAt(other);
        }
        else {
            _columnEncoders.add(other);
            _columnEncoders.sort(null);
        }
    }

    @Override
    public void updateIndexRanges(long[] beginDims, long[] endDims, int colOffset) {
        for (ColumnEncoder enc : _columnEncoders) {
            enc.updateIndexRanges(beginDims, endDims, colOffset);
        }
    }

    /**
     * Returns the stored metadata if present. Otherwise asks every wrapped encoder to write its
     * metadata into {@code out}, and returns {@code out}.
     */
    @Override
    public FrameBlock getMetaData(FrameBlock out) {
        if (_meta != null)
            return _meta;
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            columnEncoder.getMetaData(out);
        }
        return out;
    }

    @Override
    public void initMetaData(FrameBlock out) {
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            columnEncoder.initMetaData(out);
        }
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("CompositeEncoder(").append(_columnEncoders.size()).append("):\n");
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            sb.append("-- ");
            sb.append(columnEncoder.getClass().getSimpleName());
            sb.append(": ");
            sb.append(columnEncoder._colID);
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * Wire format:
     * <ol>
     *   <li>the encoder count;</li>
     *   <li>per encoder: its column ID, its type byte, then its own payload;</li>
     *   <li>a presence flag for the metadata, followed by the metadata itself if present.</li>
     * </ol>
     * The composite's own column ID is not written.
     */
    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        out.writeInt(_columnEncoders.size());
        for (ColumnEncoder columnEncoder : _columnEncoders) {
            out.writeInt(columnEncoder._colID);
            out.writeByte(EncoderFactory.getEncoderType(columnEncoder));
            columnEncoder.writeExternal(out);
        }
        out.writeBoolean(_meta != null);
        if (_meta != null)
            _meta.write(out);
    }

    /** Reads the format written by {@link #writeExternal}, replacing the current encoders. */
    @Override
    public void readExternal(ObjectInput in) throws IOException {
        int encodersSize = in.readInt();
        _columnEncoders = new ArrayList<>();
        for (int i = 0; i < encodersSize; i++) {
            int colID = in.readInt();
            ColumnEncoder columnEncoder = EncoderFactory.createInstance(in.readByte());
            columnEncoder.readExternal(in);
            columnEncoder.setColID(colID);
            _columnEncoders.add(columnEncoder);
        }
        if (in.readBoolean()) {
            FrameBlock meta = new FrameBlock();
            meta.readFields(in);
            _meta = meta;
        }
        // The composite's own ID is not serialized. Restore it from the children, as the list
        // constructor does, so it is not left at -1.
        if (!_columnEncoders.isEmpty())
            _colID = _columnEncoders.get(0)._colID;
    }

    /** @return true if a wrapped encoder has exactly class {@code type} */
    public <T extends ColumnEncoder> boolean hasEncoder(Class<T> type) {
        return _columnEncoders.stream().anyMatch(encoder -> encoder.getClass().equals(type));
    }

    /** Shifts this composite's column and the column of every wrapped encoder. */
    @Override
    public void shiftCol(int columnOffset) {
        super.shiftCol(columnOffset);
        _columnEncoders.forEach(e -> e.shiftCol(columnOffset));
    }

    /** Union of the wrapped encoders' sparse-rows-with-zeros sets; null sets are skipped. */
    @Override
    public Set<Integer> getSparseRowsWZeros() {
        return _columnEncoders.stream()
            .map(ColumnEncoder::getSparseRowsWZeros)
            .filter(Objects::nonNull)
            .flatMap(rows -> rows.stream())
            .collect(Collectors.toSet());
    }

    /** Task that refreshes the dummycode domain sizes of a composite; appended by {@link #getBuildTasks}. */
    private static class ColumnCompositeUpdateDCTask
    implements Callable<Object> {
        private final ColumnEncoderComposite _encoder;

        protected ColumnCompositeUpdateDCTask(ColumnEncoderComposite encoder) {
            this._encoder = encoder;
        }

        @Override
        public Void call() throws Exception {
            _encoder.updateAllDCEncoders();
            return null;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + "<ColId: " + _encoder._colID + ">";
        }
    }
}

