package cz.cas.mbu.genexpi.compute;

import com.nativelibs4java.opencl.CLBuffer;
import com.nativelibs4java.opencl.CLContext;
import com.nativelibs4java.opencl.CLDevice;
import com.nativelibs4java.opencl.CLEvent;
import com.nativelibs4java.opencl.CLException;
import com.nativelibs4java.opencl.CLKernel;
import com.nativelibs4java.opencl.CLMem;
import com.nativelibs4java.opencl.CLPlatform;
import com.nativelibs4java.opencl.CLProgram;
import com.nativelibs4java.opencl.CLQueue;
import com.nativelibs4java.opencl.JavaCL;
import com.nativelibs4java.opencl.LocalSize;
import com.nativelibs4java.util.IOUtils;
import java.io.IOException;
import java.lang.Number;
import java.net.URL;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.bridj.Pointer;

/* loaded from: input_file:genexpi-compute-1.3.0.jar:cz/cas/mbu/genexpi/compute/GNCompute.class */
public class GNCompute<NUMBER_TYPE extends Number> {
    private final Class<NUMBER_TYPE> elementClass;
    private final CLContext context;
    private final InferenceModel model;
    private final EMethod method;
    private final EErrorFunction errorFunction;
    private final ELossFunction lossFunction;
    private final boolean useCustomTimeStep;
    private final float customTimeStep;
    private final CLProgram program;
    private final CLKernel kernel;
    private boolean verbose = true;
    private static /* synthetic */ int[] $SWITCH_TABLE$cz$cas$mbu$genexpi$compute$RegulationType;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:genexpi-compute-1.3.0.jar:cz/cas/mbu/genexpi/compute/GNCompute$OutputPointers.class */
    public class OutputPointers {
        CLBuffer<NUMBER_TYPE> optimizedParams;
        CLBuffer<NUMBER_TYPE> errors;

        public OutputPointers(CLBuffer<NUMBER_TYPE> cLBuffer, CLBuffer<NUMBER_TYPE> cLBuffer2) {
            this.optimizedParams = cLBuffer;
            this.errors = cLBuffer2;
        }

        public CLBuffer<NUMBER_TYPE> getOptimizedParams() {
            return this.optimizedParams;
        }

        public CLBuffer<NUMBER_TYPE> getErrors() {
            return this.errors;
        }

        public long getByteCount() {
            return this.optimizedParams.getByteCount() + this.errors.getByteCount();
        }
    }

    public static CLDevice getBestDevice() {
        CLDevice bestDevice = JavaCL.getBestDevice(CLPlatform.DeviceFeature.OutOfOrderQueueSupport, CLPlatform.DeviceFeature.GPU, CLPlatform.DeviceFeature.MaxComputeUnits);
        if (bestDevice != null) {
            return bestDevice;
        }
        CLDevice bestDevice2 = JavaCL.getBestDevice(CLPlatform.DeviceFeature.OutOfOrderQueueSupport, CLPlatform.DeviceFeature.MaxComputeUnits);
        return bestDevice2 != null ? bestDevice2 : JavaCL.getBestDevice();
    }

    public static CLContext getBestContext() {
        CLDevice bestDevice = getBestDevice();
        return bestDevice.getPlatform().createContext(null, bestDevice);
    }

    public GNCompute(Class<NUMBER_TYPE> cls, CLContext cLContext, InferenceModel inferenceModel, EMethod eMethod, EErrorFunction eErrorFunction, ELossFunction eLossFunction, boolean z, Float f) throws IOException {
        this.elementClass = cls;
        this.context = cLContext;
        this.model = inferenceModel;
        this.method = eMethod;
        this.errorFunction = eErrorFunction;
        this.lossFunction = eLossFunction;
        this.useCustomTimeStep = z;
        if (z) {
            this.customTimeStep = f.floatValue();
        } else {
            this.customTimeStep = Float.NaN;
        }
        String kernelName = inferenceModel.getKernelName(eMethod.getKernelBaseName());
        String readText = IOUtils.readText(GNCompute.class.getResource("Definitions.clh"));
        this.program = cLContext.createProgram(String.valueOf(readText) + IOUtils.readText(GNCompute.class.getResource("Utils.cl")) + IOUtils.readText(GNCompute.class.getResource("XorShift1024.cl")) + IOUtils.readText(GNCompute.class.getResource(inferenceModel.getModelSource())) + IOUtils.readText(GNCompute.class.getResource(eMethod.getMethodSource())));
        if (eErrorFunction != null) {
            this.program.defineMacro(eErrorFunction.getMacro(), 1);
        }
        this.program.defineMacro(eLossFunction.getMacro(), 1);
        this.program.defineMacro("GRADIENT_UPDATE", "CTSW_OFF");
        if (z) {
            this.program.defineMacro("CUSTOM_TIME_STEP", Float.toString(f.floatValue()));
        }
        if (inferenceModel.getAdditionalDefines() != null) {
            for (int i = 0; i < inferenceModel.getAdditionalDefines().length; i++) {
                this.program.defineMacro(inferenceModel.getAdditionalDefines()[i][0], inferenceModel.getAdditionalDefines()[i][1]);
            }
        }
        this.kernel = this.program.createKernel(kernelName, new Object[0]);
    }

    public boolean isVerbose() {
        return this.verbose;
    }

    public void setVerbose(boolean z) {
        this.verbose = z;
    }

    private CLEvent[] executeKernel(CLQueue cLQueue, int i, int i2, boolean z) {
        int max = z ? Math.max(cLQueue.getDevice().getMaxComputeUnits() - 1, 1) : 128;
        ArrayList arrayList = new ArrayList();
        int i3 = 0;
        while (true) {
            int i4 = i3;
            if (i4 >= i) {
                CLEvent[] cLEventArr = new CLEvent[arrayList.size()];
                arrayList.toArray(cLEventArr);
                return cLEventArr;
            }
            CLEvent enqueueNDRange = this.kernel.enqueueNDRange(cLQueue, new long[]{i4}, new long[]{i - i4 > max ? max : r0, i2}, new long[]{1, i2}, new CLEvent[0]);
            if (z) {
                enqueueNDRange.waitFor();
            } else {
                arrayList.add(enqueueNDRange);
            }
            i3 = i4 + max;
        }
    }

    private long prepareXorShiftParameters(List<Object> list, int i, int i2) {
        Pointer<Long> order = Pointer.allocateLongs(i * i2 * 16).order(this.context.getByteOrder());
        XorShift1024 xorShift1024 = new XorShift1024();
        xorShift1024.InitFromSecureRandomAndSplitMix();
        for (int i3 = 0; i3 < i; i3++) {
            for (int i4 = 0; i4 < i2; i4++) {
                xorShift1024.Jump();
                for (int i5 = 0; i5 < 16; i5++) {
                    order.set((((i5 * i) + i3) * i2) + i4, Long.valueOf(xorShift1024.GetState()[i5]));
                }
            }
        }
        CLBuffer<Long> createLongBuffer = this.context.createLongBuffer(CLMem.Usage.InputOutput, order);
        list.add(createLongBuffer);
        CLBuffer<Integer> createIntBuffer = this.context.createIntBuffer(CLMem.Usage.InputOutput, i * i2);
        list.add(createIntBuffer);
        return createLongBuffer.getByteCount() + createIntBuffer.getByteCount();
    }

    private long prepareBaseParameters(List<Object> list, List<GeneProfile<NUMBER_TYPE>> list2, int[] iArr) {
        ByteOrder byteOrder = this.context.getByteOrder();
        int size = list2.get(0).getProfile().size();
        Pointer order = Pointer.allocateArray((Class) this.elementClass, size * list2.size()).order(byteOrder);
        boolean z = false;
        for (int i = 0; i < list2.size(); i++) {
            List<NUMBER_TYPE> profile = list2.get(i).getProfile();
            for (int i2 = 0; i2 < size; i2++) {
                NUMBER_TYPE number_type = profile.get(i2);
                order.set((i2 * list2.size()) + i, number_type);
                if ((number_type.doubleValue() < 0.0d || Double.isInfinite(number_type.doubleValue()) || Double.isNaN(number_type.doubleValue())) && !z) {
                    System.out.println("Warning: some gene values are negative, infinity or NaN (e.g. for gene `" + list2.get(i).getName() + "` and time " + i2 + ")");
                    z = true;
                }
            }
        }
        CLBuffer createBuffer = this.context.createBuffer(CLMem.Usage.Input, order);
        long byteCount = 0 + createBuffer.getByteCount();
        list.add(createBuffer);
        long prepareProfileIDParameter = byteCount + prepareProfileIDParameter(list, iArr);
        list.add(LocalSize.ofFloatArray(size));
        list.add(Integer.valueOf(list2.size()));
        list.add(Integer.valueOf(size));
        return prepareProfileIDParameter;
    }

    private long prepareProfileIDParameter(List<Object> list, int[] iArr) {
        Pointer<Integer> order = Pointer.allocateInts(iArr.length).order(this.context.getByteOrder());
        order.setInts(iArr);
        CLBuffer<Integer> createIntBuffer = this.context.createIntBuffer(CLMem.Usage.Input, order);
        list.add(createIntBuffer);
        return createIntBuffer.getByteCount();
    }

    private long prepareWeightConstraintsParameter(List<Object> list, int[] iArr) {
        if (iArr == null) {
            list.add(null);
            return 0L;
        }
        Pointer<Integer> order = Pointer.allocateInts(iArr.length).order(this.context.getByteOrder());
        order.setInts(iArr);
        CLBuffer<Integer> createIntBuffer = this.context.createIntBuffer(CLMem.Usage.Input, order);
        list.add(createIntBuffer);
        return createIntBuffer.getByteCount();
    }

    private GNCompute<NUMBER_TYPE>.OutputPointers prepareOutputParameters(List<Object> list, int i, int i2) {
        CLBuffer createBuffer = this.context.createBuffer(CLMem.Usage.InputOutput, this.elementClass, i * i2 * this.model.getNumParams());
        CLBuffer createBuffer2 = this.context.createBuffer(CLMem.Usage.InputOutput, this.elementClass, i * i2);
        list.add(createBuffer);
        list.add(createBuffer2);
        return new OutputPointers(createBuffer, createBuffer2);
    }

    private List<InferenceResult> gatherInferenceResults(CLQueue cLQueue, GNCompute<NUMBER_TYPE>.OutputPointers outputPointers, CLEvent[] cLEventArr, int i, int i2, boolean z) {
        try {
            Pointer read = outputPointers.getOptimizedParams().read(cLQueue, cLEventArr);
            Pointer read2 = outputPointers.getErrors().read(cLQueue, cLEventArr);
            ArrayList arrayList = new ArrayList(i);
            int i3 = 0;
            for (int i4 = 0; i4 < i; i4++) {
                double d = Double.POSITIVE_INFINITY;
                int i5 = -1;
                for (int i6 = 0; i6 < i2; i6++) {
                    Number number = (Number) read2.get((i4 * i2) + i6);
                    if (Float.isNaN(number.floatValue())) {
                        i3++;
                    }
                    if (number.doubleValue() < d) {
                        d = number.doubleValue();
                        i5 = i6;
                    }
                }
                double[] dArr = new double[this.model.getNumParams()];
                for (int i7 = 0; i7 < this.model.getNumParams(); i7++) {
                    if (i5 < 0) {
                        dArr[i7] = Double.NaN;
                    } else {
                        dArr[i7] = ((Number) read.get((((i7 * i) + i4) * i2) + i5)).doubleValue();
                    }
                }
                arrayList.add(new InferenceResult(dArr, d));
            }
            if (i3 > 0) {
                System.out.println("Encountered " + i3 + " NaNs out of " + (i * i2) + " runs.");
            }
            return arrayList;
        } catch (CLException.OutOfResources e) {
            if (!Arrays.stream(this.context.getDevices()).anyMatch(cLDevice -> {
                return cLDevice.getType().contains(CLDevice.Type.GPU);
            }) || z) {
                throw e;
            }
            throw new SuspectGPUResetByOSException("The computation was not succesful, a possible cause is a reset by the OS.\nIf your are running computations on the same GPU the runs your main display, consider preventing full occupation of the GPU or running on a CPU instead.", e);
        }
    }

    private int regulationTypeToInt(RegulationType regulationType) {
        switch ($SWITCH_TABLE$cz$cas$mbu$genexpi$compute$RegulationType()[regulationType.ordinal()]) {
            case 1:
                return 0;
            case 2:
                return 1;
            case 3:
                return -1;
            default:
                throw new IllegalStateException("Unrecognized regulation type: " + regulationType);
        }
    }

    public List<InferenceResult> computeAdditiveRegulation(List<GeneProfile<NUMBER_TYPE>> list, List<AdditiveRegulationInferenceTask> list2, int i, int i2, float f, boolean z) throws IOException {
        CLQueue createDefaultQueue;
        if (list2.isEmpty()) {
            return Collections.EMPTY_LIST;
        }
        long nanoTime = System.nanoTime();
        try {
            createDefaultQueue = this.context.createDefaultOutOfOrderQueue();
        } catch (CLException e) {
            if (this.verbose) {
                System.out.println("Could not create out-of-order queue. Using default queue.");
            }
            createDefaultQueue = this.context.createDefaultQueue(new CLDevice.QueueProperties[0]);
        }
        int size = list2.size();
        int[] iArr = new int[size * i];
        int[] iArr2 = new int[size];
        int[] iArr3 = new int[size * i];
        for (int i3 = 0; i3 < size; i3++) {
            iArr2[i3] = list2.get(i3).getTargetID();
            int[] regulatorIDs = list2.get(i3).getRegulatorIDs();
            if (i != regulatorIDs.length) {
                throw new GNException("Inconsistent regulator numbers");
            }
            for (int i4 = 0; i4 < i; i4++) {
                iArr[(i3 * i) + i4] = regulatorIDs[i4];
                iArr3[(i3 * i) + i4] = regulationTypeToInt(list2.get(i3).getRegulationTypes()[i4]);
            }
        }
        ArrayList arrayList = new ArrayList();
        int size2 = list.get(0).getProfile().size();
        long prepareXorShiftParameters = 0 + prepareXorShiftParameters(arrayList, size, i2) + prepareBaseParameters(arrayList, list, iArr2);
        arrayList.add(Integer.valueOf(size));
        arrayList.add(Integer.valueOf(i2));
        long prepareProfileIDParameter = prepareXorShiftParameters + prepareProfileIDParameter(arrayList, iArr) + prepareWeightConstraintsParameter(arrayList, iArr3);
        arrayList.add(LocalSize.ofFloatArray(size2 * i));
        arrayList.add(LocalSize.ofFloatArray(i + 1));
        arrayList.add(Float.valueOf(f));
        arrayList.add(LocalSize.ofIntArray(i));
        GNCompute<NUMBER_TYPE>.OutputPointers prepareOutputParameters = prepareOutputParameters(arrayList, size, i2);
        long byteCount = (prepareProfileIDParameter + prepareOutputParameters.getByteCount()) / 1048576;
        if (this.verbose) {
            System.out.println("Allocating " + byteCount + "MB.");
        }
        this.kernel.setArgs(arrayList.toArray());
        long nanoTime2 = System.nanoTime();
        float f2 = ((float) (nanoTime2 - nanoTime)) / 1000000.0f;
        if (this.verbose) {
            System.out.println("Preparation took: " + f2 + " ms.");
        }
        List<InferenceResult> gatherInferenceResults = gatherInferenceResults(createDefaultQueue, prepareOutputParameters, executeKernel(createDefaultQueue, size, i2, z), size, i2, z);
        createDefaultQueue.finish();
        float nanoTime3 = ((float) (System.nanoTime() - nanoTime2)) / 1000000.0f;
        if (this.verbose) {
            System.out.println("Computation took: " + nanoTime3 + " ms.");
        }
        return gatherInferenceResults;
    }

    public List<InferenceResult> computeNoRegulator(List<GeneProfile<NUMBER_TYPE>> list, List<NoRegulatorInferenceTask> list2, int i, boolean z) throws IOException {
        CLQueue createDefaultQueue;
        if (list2.isEmpty()) {
            return Collections.EMPTY_LIST;
        }
        long nanoTime = System.nanoTime();
        try {
            createDefaultQueue = this.context.createDefaultOutOfOrderQueue();
        } catch (CLException e) {
            if (this.verbose) {
                System.out.println("Could not create out-of-order queue. Using default queue.");
            }
            createDefaultQueue = this.context.createDefaultQueue(new CLDevice.QueueProperties[0]);
        }
        int size = list2.size();
        int[] iArr = new int[size];
        for (int i2 = 0; i2 < size; i2++) {
            iArr[i2] = list2.get(i2).getTargetID();
        }
        ArrayList arrayList = new ArrayList();
        long prepareXorShiftParameters = 0 + prepareXorShiftParameters(arrayList, size, i) + prepareBaseParameters(arrayList, list, iArr);
        arrayList.add(Integer.valueOf(size));
        arrayList.add(Integer.valueOf(i));
        GNCompute<NUMBER_TYPE>.OutputPointers prepareOutputParameters = prepareOutputParameters(arrayList, size, i);
        long byteCount = (prepareXorShiftParameters + prepareOutputParameters.getByteCount()) / 1048576;
        if (this.verbose) {
            System.out.println("Allocating " + byteCount + "MB.");
        }
        this.kernel.setArgs(arrayList.toArray());
        long nanoTime2 = System.nanoTime();
        float f = ((float) (nanoTime2 - nanoTime)) / 1000000.0f;
        if (this.verbose) {
            System.out.println("Preparation took: " + f + " ms.");
        }
        List<InferenceResult> gatherInferenceResults = gatherInferenceResults(createDefaultQueue, prepareOutputParameters, executeKernel(createDefaultQueue, size, i, z), size, i, z);
        createDefaultQueue.finish();
        float nanoTime3 = ((float) (System.nanoTime() - nanoTime2)) / 1000000.0f;
        if (this.verbose) {
            System.out.println("Computation took: " + nanoTime3 + " ms.");
        }
        return gatherInferenceResults;
    }

    public InferenceModel getModel() {
        return this.model;
    }

    public EMethod getMethod() {
        return this.method;
    }

    public EErrorFunction getErrorFunction() {
        return this.errorFunction;
    }

    public ELossFunction getLossFunction() {
        return this.lossFunction;
    }

    public boolean isUseCustomTimeStep() {
        return this.useCustomTimeStep;
    }

    public float getCustomTimeStep() {
        return this.customTimeStep;
    }

    static /* synthetic */ int[] $SWITCH_TABLE$cz$cas$mbu$genexpi$compute$RegulationType() {
        int[] iArr = $SWITCH_TABLE$cz$cas$mbu$genexpi$compute$RegulationType;
        if (iArr != null) {
            return iArr;
        }
        int[] iArr2 = new int[RegulationType.valuesCustom().length];
        try {
            iArr2[RegulationType.All.ordinal()] = 1;
        } catch (NoSuchFieldError unused) {
        }
        try {
            iArr2[RegulationType.NegativeOnly.ordinal()] = 3;
        } catch (NoSuchFieldError unused2) {
        }
        try {
            iArr2[RegulationType.PositiveOnly.ordinal()] = 2;
        } catch (NoSuchFieldError unused3) {
        }
        $SWITCH_TABLE$cz$cas$mbu$genexpi$compute$RegulationType = iArr2;
        return iArr2;
    }
}
