package com.aparapi.internal.kernel;

import com.aparapi.Config;
import com.aparapi.Kernel;
import com.aparapi.device.Device;
import com.aparapi.device.JavaDevice;
import com.aparapi.device.OpenCLDevice;
import com.aparapi.internal.util.Reflection;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import org.apache.commons.lang3.StringUtils;

/* loaded from: input_file:aparapi-2.0.0.jar:com/aparapi/internal/kernel/KernelManager.class */
/**
 * Registry of per-kernel-class state: device preferences, profiling data and shared kernel
 * instances.
 *
 * <p>A process-wide manager is available via {@link #instance()} and may be replaced with
 * {@link #setKernelManager(KernelManager)}. Each internal map is guarded by its own monitor.
 * The reporting methods copy a map under that monitor before iterating it, so a concurrent
 * registration (e.g. from {@link #getProfile(Class)}) cannot break the iteration.
 */
public class KernelManager {
    /** Width, in characters, of the per-kernel separator line in {@link #reportProfilingSummary}. */
    private static final int PROFILE_SUMMARY_WIDTH = 132;

    private static KernelManager INSTANCE = new KernelManager();

    // Each map is guarded by synchronizing on the map itself.
    private final LinkedHashMap<Class<? extends Kernel>, PreferencesWrapper> preferences = new LinkedHashMap<>();
    private final LinkedHashMap<Class<? extends Kernel>, KernelProfile> profiles = new LinkedHashMap<>();
    private final LinkedHashMap<Class<? extends Kernel>, Kernel> sharedInstances = new LinkedHashMap<>();

    private KernelPreferences defaultPreferences;

    /** Legacy device-lookup helpers kept for backward compatibility. */
    public static class DeprecatedMethods {
        /**
         * Returns the first device of the given type among the default preferred devices.
         *
         * @param type the device type to look for
         * @return the first matching device, or {@code null} if none matches
         */
        @Deprecated
        public static Device firstDevice(Device.TYPE type) {
            for (Device device : KernelManager.instance().getDefaultPreferences().getPreferredDevices(null)) {
                if (device.getType() == type) {
                    return device;
                }
            }
            return null;
        }

        /** Returns the first default preferred device of type GPU, or {@code null}. */
        @Deprecated
        public static Device bestGPU() {
            return firstDevice(Device.TYPE.GPU);
        }

        /** Returns the first default preferred device of type ACC, or {@code null}. */
        @Deprecated
        public static Device bestACC() {
            return firstDevice(Device.TYPE.ACC);
        }
    }

    /**
     * Creates a manager and initializes its default preferences via {@link #setup()}.
     *
     * <p>NOTE(review): this calls the overridable {@code setup()} from the constructor, so a
     * subclass override runs before the subclass's own fields are initialized.
     */
    public KernelManager() {
        setup();
    }

    /** Initializes {@link #getDefaultPreferences() default preferences}; called from the constructor. */
    protected void setup() {
        this.defaultPreferences = createDefaultPreferences();
    }

    /** Returns the current process-wide manager. */
    public static KernelManager instance() {
        return INSTANCE;
    }

    /** Replaces the process-wide manager returned by {@link #instance()}. */
    public static void setKernelManager(KernelManager kernelManager) {
        INSTANCE = kernelManager;
    }

    /**
     * Returns the shared instance of {@code cls} held by the current manager, creating it on first use.
     *
     * @throws RuntimeException wrapping any failure to reflectively construct the kernel
     */
    public static <T extends Kernel> T sharedKernelInstance(Class<T> cls) {
        return instance().getSharedKernelInstance(cls);
    }

    /**
     * Appends, for every kernel class that has registered preferences, the device it uses and
     * any devices that failed. Optionally, it also appends that class's mean profiling rows.
     *
     * @param builder           destination for the report
     * @param withProfilingInfo whether to append profiling rows for classes that have a profile
     */
    public void reportDeviceUsage(StringBuilder builder, boolean withProfilingInfo) {
        builder.append("Device Usage by Kernel Subclass");
        if (withProfilingInfo) {
            builder.append(" (showing mean elapsed times in milliseconds)");
        }
        builder.append("\n\n");

        // Snapshot under the map's monitor so concurrent getPreferences() calls cannot break iteration.
        List<PreferencesWrapper> wrappers;
        synchronized (this.preferences) {
            wrappers = new ArrayList<>(this.preferences.values());
        }

        for (PreferencesWrapper wrapper : wrappers) {
            KernelPreferences kernelPreferences = wrapper.getPreferences();
            Class<? extends Kernel> kernelClass = wrapper.getKernelClass();
            KernelProfile kernelProfile = withProfilingInfo ? findProfile(kernelClass) : null;

            builder.append(kernelClass.getName())
                    .append(":\n\tusing ")
                    .append(kernelPreferences.getPreferredDevice(null).getShortDescription());

            List<Device> failedDevices = kernelPreferences.getFailedDevices();
            if (!failedDevices.isEmpty()) {
                builder.append(", failed devices = ");
                for (int i = 0; i < failedDevices.size(); i++) {
                    if (i > 0) {
                        builder.append(" | ");
                    }
                    builder.append(failedDevices.get(i).getShortDescription());
                }
            }

            if (kernelProfile != null) {
                builder.append(StringUtils.LF);
                // The table header is printed only if there is at least one device row.
                boolean headerWritten = false;
                for (KernelDeviceProfile deviceProfile : kernelProfile.getDeviceProfiles()) {
                    if (!headerWritten) {
                        builder.append(KernelDeviceProfile.getTableHeader()).append(StringUtils.LF);
                        headerWritten = true;
                    }
                    builder.append(deviceProfile.getAverageAsTableRow()).append(StringUtils.LF);
                }
            }
            builder.append(StringUtils.LF);
        }
    }

    /**
     * Appends a table of mean profiling rows for every profiled kernel class. Each class is
     * introduced by a dashed separator line padded to {@value #PROFILE_SUMMARY_WIDTH} characters
     * (long class names are not truncated).
     *
     * @param builder destination for the report
     */
    public void reportProfilingSummary(StringBuilder builder) {
        builder.append("\nProfiles by Kernel Subclass (mean elapsed times in milliseconds)\n\n");
        builder.append(KernelDeviceProfile.getTableHeader()).append(StringUtils.LF);

        // Snapshot under the map's monitor: this may run from the shutdown hook while other
        // threads are still registering profiles via getProfile().
        LinkedHashMap<Class<? extends Kernel>, KernelProfile> snapshot;
        synchronized (this.profiles) {
            snapshot = new LinkedHashMap<>(this.profiles);
        }

        for (Class<? extends Kernel> kernelClass : snapshot.keySet()) {
            String title = "----------------- [[ " + Reflection.getSimpleName(kernelClass) + " ]] ";
            builder.append(title);
            for (int column = title.length(); column < PROFILE_SUMMARY_WIDTH; column++) {
                builder.append('-');
            }
            builder.append(StringUtils.LF);
            for (KernelDeviceProfile deviceProfile : snapshot.get(kernelClass).getDeviceProfiles()) {
                builder.append(deviceProfile.getAverageAsTableRow()).append(StringUtils.LF);
            }
        }
    }

    /**
     * Returns the preferences for {@code kernel}'s runtime class. On first use it creates and
     * registers a new {@link KernelPreferences} for that class.
     */
    public KernelPreferences getPreferences(Kernel kernel) {
        Class<? extends Kernel> kernelClass = kernel.getClass();
        synchronized (this.preferences) {
            PreferencesWrapper existing = this.preferences.get(kernelClass);
            if (existing != null) {
                return existing.getPreferences();
            }
            KernelPreferences created = new KernelPreferences(this, kernelClass);
            this.preferences.put(kernelClass, new PreferencesWrapper(kernelClass, created));
            return created;
        }
    }

    /** Sets the preferred devices for {@code kernel}'s runtime class. */
    public void setPreferredDevices(Kernel kernel, LinkedHashSet<Device> linkedHashSet) {
        getPreferences(kernel).setPreferredDevices(linkedHashSet);
    }

    /** Returns the preferences that are not tied to a specific kernel class. */
    public KernelPreferences getDefaultPreferences() {
        return this.defaultPreferences;
    }

    /** Replaces the preferred devices in the {@link #getDefaultPreferences() default preferences}. */
    public void setDefaultPreferredDevices(LinkedHashSet<Device> linkedHashSet) {
        this.defaultPreferences.setPreferredDevices(linkedHashSet);
    }

    /**
     * Builds the default preferences (with a {@code null} kernel class) and populates them from
     * {@link #createDefaultPreferredDevices()}.
     */
    protected KernelPreferences createDefaultPreferences() {
        KernelPreferences kernelPreferences = new KernelPreferences(this, null);
        kernelPreferences.setPreferredDevices(createDefaultPreferredDevices());
        return kernelPreferences;
    }

    /** Returns the cached shared instance of {@code cls}, constructing it via its public no-arg constructor if needed. */
    private <T extends Kernel> T getSharedKernelInstance(Class<T> cls) {
        synchronized (this.sharedInstances) {
            Kernel shared = this.sharedInstances.get(cls);
            if (shared == null) {
                try {
                    Constructor<T> constructor = cls.getConstructor();
                    // Needed to invoke a public constructor of a class that is itself not accessible.
                    constructor.setAccessible(true);
                    shared = constructor.newInstance();
                    this.sharedInstances.put(cls, shared);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            // Values are only ever stored under their own class, so this cast cannot fail.
            return cls.cast(shared);
        }
    }

    /**
     * Builds the default device list by walking {@link #getPreferredDeviceTypes()} in order.
     * ACC and GPU OpenCL devices are sorted by their default comparators; CPU devices keep the
     * order returned by {@code OpenCLDevice.listDevices}.
     *
     * @throws AssertionError if {@code UNKNOWN} appears among the preferred device types
     */
    protected LinkedHashSet<Device> createDefaultPreferredDevices() {
        LinkedHashSet<Device> devices = new LinkedHashSet<>();
        List<OpenCLDevice> accelerators = OpenCLDevice.listDevices(Device.TYPE.ACC);
        List<OpenCLDevice> gpus = OpenCLDevice.listDevices(Device.TYPE.GPU);
        List<OpenCLDevice> cpus = OpenCLDevice.listDevices(Device.TYPE.CPU);
        Collections.sort(accelerators, getDefaultAcceleratorComparator());
        Collections.sort(gpus, getDefaultGPUComparator());

        for (Device.TYPE type : getPreferredDeviceTypes()) {
            switch (type) {
                case UNKNOWN:
                    throw new AssertionError("UNKNOWN device type not supported");
                case GPU:
                    devices.addAll(gpus);
                    break;
                case CPU:
                    devices.addAll(cpus);
                    break;
                case JTP:
                    devices.add(JavaDevice.THREAD_POOL);
                    break;
                case SEQ:
                    devices.add(JavaDevice.SEQUENTIAL);
                    break;
                case ACC:
                    devices.addAll(accelerators);
                    break;
                case ALT:
                    devices.add(JavaDevice.ALTERNATIVE_ALGORITHM);
                    break;
                default:
                    // Any other type contributes no devices.
                    break;
            }
        }
        return devices;
    }

    /** Returns the device types, in priority order, used to build the default device list. */
    protected List<Device.TYPE> getPreferredDeviceTypes() {
        return Arrays.asList(Device.TYPE.ACC, Device.TYPE.GPU, Device.TYPE.CPU, Device.TYPE.ALT, Device.TYPE.JTP);
    }

    /** Orders accelerators by descending max compute units. */
    protected Comparator<OpenCLDevice> getDefaultAcceleratorComparator() {
        return new Comparator<OpenCLDevice>() {
            @Override
            public int compare(OpenCLDevice left, OpenCLDevice right) {
                // Integer.compare instead of subtraction, which can overflow.
                return Integer.compare(right.getMaxComputeUnits(), left.getMaxComputeUnits());
            }
        };
    }

    /**
     * Orders GPUs so that the device preferred by {@link #selectLhs} comes first. Devices where
     * neither is preferred compare as equal (0), as the {@link Comparator} contract requires.
     *
     * <p>NOTE(review): NVIDIA vs non-NVIDIA pairs use different criteria from non-NVIDIA pairs,
     * so transitivity is not guaranteed for mixed-vendor device sets.
     */
    protected Comparator<OpenCLDevice> getDefaultGPUComparator() {
        return new Comparator<OpenCLDevice>() {
            @Override
            public int compare(OpenCLDevice left, OpenCLDevice right) {
                if (KernelManager.selectLhs(left, right)) {
                    return -1;
                }
                if (KernelManager.selectLhs(right, left)) {
                    return 1;
                }
                return 0;
            }
        };
    }

    /** Returns the first device of the default preferences. */
    public Device bestDevice() {
        return getDefaultPreferences().getPreferredDevice(null);
    }

    /**
     * Returns whether {@code lhs} is strictly preferred over {@code rhs}. If either device's
     * platform vendor contains "nvidia" (case-insensitively), {@link #selectLhsIfCUDA} decides.
     * Otherwise the device with more compute units wins.
     */
    protected static boolean selectLhs(OpenCLDevice lhs, OpenCLDevice rhs) {
        if (isNvidia(lhs) || isNvidia(rhs)) {
            return selectLhsIfCUDA(lhs, rhs);
        }
        return lhs.getMaxComputeUnits() > rhs.getMaxComputeUnits();
    }

    /** Returns whether the device's platform vendor name contains "nvidia", independent of default locale. */
    private static boolean isNvidia(OpenCLDevice device) {
        // Locale.ROOT: under e.g. a Turkish default locale "NVIDIA".toLowerCase() would yield "nvıdıa".
        return device.getOpenCLPlatform().getVendor().toLowerCase(Locale.ROOT).contains("nvidia");
    }

    /**
     * Tie-break used when an NVIDIA device is involved. The criteria are applied in order:
     * <ol>
     *   <li>lower device-type rank wins;</li>
     *   <li>otherwise, larger max work-group size wins;</li>
     *   <li>otherwise, larger global memory wins.</li>
     * </ol>
     * All comparisons are strict.
     */
    protected static boolean selectLhsIfCUDA(OpenCLDevice lhs, OpenCLDevice rhs) {
        if (lhs.getType() != rhs.getType()) {
            return selectLhsByType(lhs.getType(), rhs.getType());
        }
        if (lhs.getMaxWorkGroupSize() == rhs.getMaxWorkGroupSize()) {
            return lhs.getGlobalMemSize() > rhs.getGlobalMemSize();
        }
        return lhs.getMaxWorkGroupSize() > rhs.getMaxWorkGroupSize();
    }

    /** Returns whether {@code type} has a strictly lower rank than {@code type2}. */
    private static boolean selectLhsByType(Device.TYPE type, Device.TYPE type2) {
        return type.rank < type2.rank;
    }

    /** Returns the profile for {@code cls}, creating and registering an empty one on first use. */
    public KernelProfile getProfile(Class<? extends Kernel> cls) {
        synchronized (this.profiles) {
            KernelProfile profile = this.profiles.get(cls);
            if (profile == null) {
                profile = new KernelProfile(cls);
                this.profiles.put(cls, profile);
            }
            return profile;
        }
    }

    /** Returns the already-registered profile for {@code kernelClass}, or {@code null}; never creates one. */
    private KernelProfile findProfile(Class<? extends Kernel> kernelClass) {
        synchronized (this.profiles) {
            return this.profiles.get(kernelClass);
        }
    }

    static {
        // When Config.dumpProfilesOnExit is set, print the current manager's profiling summary at JVM shutdown.
        if (Config.dumpProfilesOnExit) {
            Runtime.getRuntime().addShutdownHook(new Thread() {
                @Override
                public void run() {
                    StringBuilder summary = new StringBuilder(2048);
                    KernelManager.instance().reportProfilingSummary(summary);
                    System.out.println(summary);
                }
            });
        }
    }
}
