package jsat.classifiers.trees;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.utils.FakeExecutor;
import jsat.utils.IntSet;
import jsat.utils.ModifiableCountDownLatch;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/ID3.class */
public class ID3 implements Classifier {
    private static final long serialVersionUID = -8473683139353205898L;
    // Description of the class label; assigned in trainC from the data set's getPredicting()
    // and used to size the per-class probability / count arrays.
    private CategoricalData predicting;
    // Descriptions of the categorical attributes, indexed by attribute id; assigned in trainC
    // from the data set's getCategories().
    private CategoricalData[] attributes;
    // Root of the learned tree; assigned by trainC and read by classify.
    private ID3Node root;
    // Tracks unfinished buildTree invocations: trainC creates it with a count of 1, buildTree
    // counts up before submitting each child task and counts down once when it returns, and
    // trainC awaits it before returning.
    private ModifiableCountDownLatch latch;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/ID3$ID3Node.class */
    /**
     * A node of the ID3 tree. A node is either a leaf, which carries a
     * {@link CategoricalResults} and no children, or an internal node, which
     * carries the id of the attribute it splits on and one child slot per
     * value of that attribute.
     */
    public static class ID3Node {
        /** Child slots of an internal node; {@code null} for a node built as a leaf. */
        ID3Node[] children;
        /** Result stored at a leaf; {@code null} for an internal node. */
        CategoricalResults cr;
        /** Id of the attribute an internal node splits on. */
        int attributeId;

        /** Creates a blank node; used only by {@link #copy()}, which fills the fields itself. */
        private ID3Node() {
        }

        /**
         * Creates an internal node with empty child slots.
         *
         * @param childCount number of child slots to allocate
         * @param splitAttribute id of the attribute this node splits on
         */
        public ID3Node(int childCount, int splitAttribute) {
            this.attributeId = splitAttribute;
            this.children = new ID3Node[childCount];
            this.cr = null;
        }

        /**
         * Creates a leaf node.
         *
         * @param leafResult the result returned when classification reaches this leaf
         */
        public ID3Node(CategoricalResults leafResult) {
            this.cr = leafResult;
            this.children = null;
        }

        /**
         * @return {@code true} if this node holds a result, {@code false} otherwise
         */
        public boolean isLeaf() {
            return cr != null;
        }

        /**
         * Stores a child in the given slot.
         *
         * @param slot index of the child slot
         * @param child the node to store there
         */
        public void setNode(int slot, ID3Node child) {
            children[slot] = child;
        }

        /**
         * @param slot index of the child slot
         * @return the node stored in that slot
         */
        public ID3Node getNode(int slot) {
            return children[slot];
        }

        /**
         * @return the id of the attribute this node splits on
         */
        public int getAttributeId() {
            return attributeId;
        }

        /**
         * @return the result stored at this node, or {@code null} if it is not a leaf
         */
        public CategoricalResults getResult() {
            return cr;
        }

        /**
         * Copies this node and, recursively, every child below it. The
         * {@link CategoricalResults} reference is shared with the copy, not duplicated.
         *
         * @return the root of the copied subtree
         */
        public ID3Node copy() {
            ID3Node duplicate = new ID3Node();
            duplicate.attributeId = attributeId;
            duplicate.cr = cr;
            if (children == null) {
                return duplicate;
            }
            ID3Node[] copiedChildren = new ID3Node[children.length];
            int slot = 0;
            for (ID3Node child : children) {
                copiedChildren[slot++] = child.copy();
            }
            duplicate.children = copiedChildren;
            return duplicate;
        }
    }

    /**
     * Classifies a data point by walking the tree from its root.
     *
     * @param dataPoint the data point to classify
     * @return the result stored at the leaf the data point ends up in
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        final ID3Node start = this.root;
        return walkTree(start, dataPoint);
    }

    /**
     * Descends from the given node to a leaf, at each internal node following the
     * child whose index equals the data point's categorical value for that node's
     * split attribute.
     *
     * @param iD3Node the node to start from
     * @param dataPoint the data point steering the descent
     * @return the result stored at the leaf that is reached
     */
    private static CategoricalResults walkTree(ID3Node iD3Node, DataPoint dataPoint) {
        ID3Node current = iD3Node;
        while (!current.isLeaf()) {
            int branch = dataPoint.getCategoricalValue(current.getAttributeId());
            current = current.getNode(branch);
        }
        return current.getResult();
    }

    /**
     * Trains the tree on the given data set, building child subtrees as tasks
     * submitted to the given executor, and blocks until every subtree is finished.
     *
     * @param classificationDataSet the training data; must contain no numerical variables
     * @param executorService the executor that child subtree tasks are submitted to
     * @throws RuntimeException if the data set has any numerical variables
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        if (classificationDataSet.getNumNumericalVars() != 0) {
            throw new RuntimeException("ID3 only supports categorical data");
        }
        this.predicting = classificationDataSet.getPredicting();
        this.attributes = classificationDataSet.getCategories();
        List<DataPointPair<Integer>> asDPPList = classificationDataSet.getAsDPPList();
        // Every categorical attribute id starts out as a candidate split.
        int numCategoricalVars = classificationDataSet.getNumCategoricalVars();
        IntSet intSet = new IntSet(numCategoricalVars);
        for (int i = 0; i < numCategoricalVars; i++) {
            intSet.add(Integer.valueOf(i));
        }
        // The initial count of 1 is consumed by the root buildTree call's own countDown.
        this.latch = new ModifiableCountDownLatch(1);
        this.root = buildTree(asDPPList, intSet, executorService);
        try {
            this.latch.await();
        } catch (InterruptedException e) {
            Logger.getLogger(ID3.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            // Restore the interrupt status so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Trains the tree on the given data set by delegating to
     * {@link #trainC(ClassificationDataSet, ExecutorService)} with a newly created
     * {@link FakeExecutor}.
     *
     * @param classificationDataSet the training data
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        final ExecutorService executor = new FakeExecutor();
        trainC(classificationDataSet, executor);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Builds the subtree for the given data points.
     * <p>
     * A leaf is produced when no candidate attributes remain or the data points
     * have zero entropy. Otherwise the candidate attribute with the largest
     * information gain is chosen, the data points are partitioned by their value
     * for it, and one task per partition is submitted to the executor to build
     * the corresponding child.
     * <p>
     * Latch protocol: this method counts the latch up once before submitting each
     * child task and counts it down exactly once before returning, so the latch
     * reaches zero only after every submitted subtree has returned.
     *
     * @param list the data points that reached this node
     * @param set ids of the attributes still available for splitting; not modified
     * @param executorService the executor that child subtree tasks are submitted to
     * @return the node for this subtree; for an internal node the child slots are
     *         filled in by the submitted tasks
     */
    public ID3Node buildTree(List<DataPointPair<Integer>> list, Set<Integer> set, final ExecutorService executorService) {
        double entropy = entropy(list);
        double size = list.size();
        if (set.isEmpty() || entropy == 0.0d) {
            int numClasses = this.predicting.getNumOfCategories();
            CategoricalResults categoricalResults = new CategoricalResults(numClasses);
            if (list.isEmpty()) {
                // No data point carries this attribute value, so there are no counts to
                // normalise. Fall back to a uniform distribution.
                // NOTE(review): assumes divideConst divides each stored probability by its
                // argument, so dividing the zero counts by a size of 0 would yield NaN - confirm.
                for (int c = 0; c < numClasses; c++) {
                    categoricalResults.setProb(c, 1.0d / numClasses);
                }
            } else {
                for (DataPointPair<Integer> dataPointPair : list) {
                    int label = dataPointPair.getPair().intValue();
                    categoricalResults.setProb(label, categoricalResults.getProb(label) + 1.0d);
                }
                categoricalResults.divideConst(size);
            }
            this.latch.countDown();
            return new ID3Node(categoricalResults);
        }
        int bestAttribute = -1;
        // NEGATIVE_INFINITY, not Double.MIN_VALUE: MIN_VALUE is the smallest POSITIVE double,
        // so with it an attribute set whose gains are all zero would select nothing and the
        // attributes[-1] lookup below would throw.
        double bestGain = Double.NEGATIVE_INFINITY;
        List<List<DataPointPair<Integer>>> bestSplit = null;
        for (int attribute : set) {
            int numValues = this.attributes[attribute].getNumOfCategories();
            // One bucket per value of the candidate attribute.
            List<List<DataPointPair<Integer>>> split = new ArrayList<List<DataPointPair<Integer>>>(numValues);
            for (int v = 0; v < numValues; v++) {
                split.add(new ArrayList<DataPointPair<Integer>>());
            }
            for (DataPointPair<Integer> dataPointPair : list) {
                split.get(dataPointPair.getDataPoint().getCategoricalValue(attribute)).add(dataPointPair);
            }
            // Size-weighted entropy remaining after splitting on this attribute.
            double remainder = 0.0d;
            for (List<DataPointPair<Integer>> part : split) {
                remainder += (entropy(part) * part.size()) / size;
            }
            double gain = entropy - remainder;
            // Strict comparison keeps the first attribute encountered among ties.
            if (gain > bestGain) {
                bestAttribute = attribute;
                bestGain = gain;
                bestSplit = split;
            }
        }
        final ID3Node iD3Node = new ID3Node(this.attributes[bestAttribute].getNumOfCategories(), bestAttribute);
        // Children may split on every remaining attribute except the one used here.
        final IntSet intSet = new IntSet(set);
        intSet.remove(Integer.valueOf(bestAttribute));
        for (int i = 0; i < bestSplit.size(); i++) {
            final int childIndex = i;
            final List<DataPointPair<Integer>> childData = bestSplit.get(childIndex);
            // Register the child before submitting it so the latch cannot reach zero early.
            this.latch.countUp();
            executorService.submit(new Runnable() { // from class: jsat.classifiers.trees.ID3.1
                @Override // java.lang.Runnable
                public void run() {
                    iD3Node.setNode(childIndex, ID3.this.buildTree(childData, intSet, executorService));
                }
            });
        }
        this.latch.countDown();
        return iD3Node;
    }

    /**
     * Computes the base-2 entropy of the class labels of the given data points.
     *
     * @param list the data points whose labels are examined
     * @return the entropy in bits; {@code 0.0} for an empty list
     */
    private double entropy(List<DataPointPair<Integer>> list) {
        if (list.isEmpty()) {
            return 0.0d;
        }
        final int total = list.size();
        // Holds per-class counts first, then per-class fractions once normalised below.
        double[] classShare = new double[this.predicting.getNumOfCategories()];
        for (DataPointPair<Integer> sample : list) {
            classShare[sample.getPair().intValue()] += 1.0d;
        }
        double signedSum = 0.0d;
        for (int c = 0; c < classShare.length; c++) {
            classShare[c] /= total;
            double p = classShare[c];
            // Classes that do not occur are skipped so log(0) is never evaluated.
            if (p != 0.0d) {
                signedSum += p * (Math.log(p) / Math.log(2.0d));
            }
        }
        // The accumulated sum of p * log2(p) is non-positive; entropy is its magnitude.
        return Math.abs(signedSum);
    }

    /**
     * Reports whether this classifier supports weighted data.
     *
     * @return always {@code false}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Creates a copy of this classifier. The attribute and label descriptions are
     * shared with the copy, the tree is copied node by node, and the copy's latch
     * is left unset.
     *
     * @return the copy; its tree is {@code null} if this classifier has no tree yet
     */
    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Classifier m584clone() {
        ID3 id3 = new ID3();
        id3.attributes = this.attributes;
        id3.latch = null;
        id3.predicting = this.predicting;
        // root is only assigned by trainC, so guard against cloning an untrained
        // classifier, which previously failed with a NullPointerException here.
        id3.root = this.root == null ? null : this.root.copy();
        return id3;
    }
}
