/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.gpu.context;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import jcuda.CudaException;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUfunction;
import jcuda.driver.CUmodule;
import jcuda.driver.CUresult;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.JCuda;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.gpu.context.ExecutionConfig;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;

public class JCudaKernels {
    private static final String ptxFileName = "/cuda/kernels/SystemDS.ptx";
    private HashMap<String, CUfunction> kernels = new HashMap();
    private CUmodule module = new CUmodule();

    JCudaKernels() {
        JCudaKernels.checkResult(JCudaDriver.cuModuleLoadDataEx((CUmodule)this.module, (Pointer)JCudaKernels.initKernels(ptxFileName), (int)0, (int[])new int[0], (Pointer)Pointer.to((int[])new int[0])));
    }

    public void launchKernel(String name, ExecutionConfig config, Object ... arguments) {
        CUfunction function = this.kernels.get(name = (String)name + LibMatrixCUDA.customKernelSuffix);
        if (function == null) {
            function = new CUfunction();
            try {
                JCudaKernels.checkResult(JCudaDriver.cuModuleGetFunction((CUfunction)function, (CUmodule)this.module, (String)name));
            }
            catch (CudaException e) {
                throw new DMLRuntimeException("Error finding the custom kernel:" + (String)name, (Exception)((Object)e));
            }
        }
        Pointer[] kernelParams = new Pointer[arguments.length];
        for (int i = 0; i < arguments.length; ++i) {
            if (arguments[i] == null) {
                throw new DMLRuntimeException("The argument to the kernel cannot be null.");
            }
            if (arguments[i] instanceof Pointer) {
                kernelParams[i] = Pointer.to((NativePointerObject[])new NativePointerObject[]{(Pointer)arguments[i]});
                continue;
            }
            if (arguments[i] instanceof Integer) {
                kernelParams[i] = Pointer.to((int[])new int[]{(Integer)arguments[i]});
                continue;
            }
            if (arguments[i] instanceof Double) {
                kernelParams[i] = Pointer.to((double[])new double[]{(Double)arguments[i]});
                continue;
            }
            if (arguments[i] instanceof Long) {
                kernelParams[i] = Pointer.to((long[])new long[]{(Long)arguments[i]});
                continue;
            }
            if (arguments[i] instanceof Float) {
                kernelParams[i] = Pointer.to((float[])new float[]{((Float)arguments[i]).floatValue()});
                continue;
            }
            throw new DMLRuntimeException("The argument of type " + arguments[i].getClass() + " is not supported.");
        }
        JCudaKernels.checkResult(JCudaDriver.cuLaunchKernel((CUfunction)function, (int)config.gridDimX, (int)config.gridDimY, (int)config.gridDimZ, (int)config.blockDimX, (int)config.blockDimY, (int)config.blockDimZ, (int)config.sharedMemBytes, (CUstream)config.stream, (Pointer)Pointer.to((NativePointerObject[])kernelParams), null));
        if (DMLScript.SYNCHRONIZE_GPU) {
            JCuda.cudaDeviceSynchronize();
        }
    }

    public static void checkResult(int cuResult) {
        if (cuResult != 0) {
            throw new DMLRuntimeException(CUresult.stringFor((int)cuResult));
        }
    }

    /*
     * Loose catch block
     */
    private static Pointer initKernels(String ptxFileName) {
        InputStream in;
        ByteArrayOutputStream out;
        block12: {
            Pointer pointer;
            block13: {
                int read;
                out = null;
                in = JCudaKernels.class.getResourceAsStream(ptxFileName);
                if (in == null) break block12;
                out = new ByteArrayOutputStream();
                byte[] buffer = new byte[8192];
                while ((read = in.read(buffer)) != -1) {
                    out.write(buffer, 0, read);
                }
                out.write(0);
                out.flush();
                pointer = Pointer.to((byte[])out.toByteArray());
                if (in == null) break block13;
                in.close();
            }
            IOUtilFunctions.closeSilently(out);
            return pointer;
        }
        try {
            try {
                throw new DMLRuntimeException("The input file " + ptxFileName + " not found. (Hint: Please compile SystemDS using -DenableGPU=true flag. Example: mvn package -DenableGPU=true).");
                {
                    catch (Throwable throwable) {
                        if (in != null) {
                            try {
                                in.close();
                            }
                            catch (Throwable throwable2) {
                                throwable.addSuppressed(throwable2);
                            }
                        }
                        throw throwable;
                    }
                }
            }
            catch (IOException e) {
                throw new DMLRuntimeException("Could not initialize the kernels", e);
            }
        }
        catch (Throwable throwable) {
            IOUtilFunctions.closeSilently(out);
            throw throwable;
        }
    }
}

