package weka.classifiers.meta;

import edu.stanford.nlp.sequences.SeqClassifierFlags;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableParallelIteratedSingleClassifierEnhancer;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.classifiers.trees.REPTree;
import weka.core.AdditionalMeasureProducer;
import weka.core.Aggregateable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.PartitionGenerator;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/* loaded from: input_file:weka/classifiers/meta/Bagging.class */
/**
 * Bootstrap-aggregating ("bagging") meta-classifier.
 *
 * <p>Each member of {@code m_Classifiers} is trained on a resample (drawn with
 * {@code Instances.resampleWithWeights}) of the training data. Predictions are combined by
 * averaging: class distributions are summed and normalized for nominal classes, and non-missing
 * numeric predictions are averaged. Optionally the out-of-bag error is estimated while building.
 */
public class Bagging extends RandomizableParallelIteratedSingleClassifierEnhancer implements WeightedInstancesHandler, AdditionalMeasureProducer, TechnicalInformationHandler, PartitionGenerator, Aggregateable<Bagging> {
    static final long serialVersionUID = -115879962237199703L;

    /** Size of each bag as a percentage of the training-set size. */
    protected int m_BagSizePercent = 100;
    /** Whether the out-of-bag error is computed in {@link #buildClassifier(Instances)}. */
    protected boolean m_CalcOutOfBag = false;
    /** Whether copies of instances in a bag are represented by weights rather than explicitly. */
    protected boolean m_RepresentUsingWeights = false;
    /** Out-of-bag error of the last build; 0 when it was not (or could not be) computed. */
    protected double m_OutOfBagError;
    /** Random generator used to seed randomizable base classifiers. */
    protected Random m_random;
    /** {@code m_inBag[i][j]} is true when training instance j was drawn into bag i (OOB mode only). */
    protected boolean[][] m_inBag;
    /** Working copy of the training data; only non-null while {@link #buildClassifier} runs. */
    protected Instances m_data;
    /** Members collected by {@link #aggregate(Bagging)} until {@link #finalizeAggregation()}. */
    protected List<Classifier> m_classifiersCache;

    /** Creates a bagging classifier whose base learner is a {@link REPTree}. */
    public Bagging() {
        this.m_Classifier = new REPTree();
    }

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the description, including the bibliographic reference
     */
    public String globalInfo() {
        return "Class for bagging a classifier to reduce variance. Can do classification and regression depending on the base learner. \n\nFor more information, see\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for bagging (Breiman, 1996).
     *
     * @return the technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Leo Breiman");
        result.setValue(TechnicalInformation.Field.YEAR, "1996");
        result.setValue(TechnicalInformation.Field.TITLE, "Bagging predictors");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Machine Learning");
        result.setValue(TechnicalInformation.Field.VOLUME, "24");
        result.setValue(TechnicalInformation.Field.NUMBER, "2");
        result.setValue(TechnicalInformation.Field.PAGES, "123-140");
        return result;
    }

    /**
     * Returns the class name of the default base learner.
     *
     * @return {@code "weka.classifiers.trees.REPTree"}
     */
    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.REPTree";
    }

    /**
     * Lists the options understood by this classifier, followed by those of the superclass.
     *
     * @return an enumeration of the available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>(3);
        newVector.addElement(new Option("\tSize of each bag, as a percentage of the\n\ttraining set size. (default 100)", "P", 1, "-P"));
        // option names carry no leading dash (cf. "P" above); the synopsis does
        newVector.addElement(new Option("\tCalculate the out of bag error.", "O", 0, "-O"));
        newVector.addElement(new Option("\tRepresent copies of instances using weights rather than explicitly.", "represent-copies-using-weights", 0, "-represent-copies-using-weights"));
        newVector.addAll(Collections.list(super.listOptions()));
        return newVector.elements();
    }

    /**
     * Parses {@code -P <percent>} (bag size, default 100), {@code -O} (compute out-of-bag error)
     * and {@code -represent-copies-using-weights}; the rest is handed to the superclass.
     *
     * @param options the option array
     * @throws Exception if an option is invalid or unrecognized options remain
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String bagSize = Utils.getOption('P', options);
        setBagSizePercent(bagSize.length() != 0 ? Integer.parseInt(bagSize) : 100);
        setCalcOutOfBag(Utils.getFlag('O', options));
        setRepresentCopiesUsingWeights(Utils.getFlag("represent-copies-using-weights", options));
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current settings as an option array (own options first, then the superclass's).
     *
     * @return the current options
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-P");
        options.add(String.valueOf(getBagSizePercent()));
        if (getCalcOutOfBag()) {
            options.add("-O");
        }
        if (getRepresentCopiesUsingWeights()) {
            options.add("-represent-copies-using-weights");
        }
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /** @return tip text for the bag-size property */
    public String bagSizePercentTipText() {
        return "Size of each bag, as a percentage of the training set size.";
    }

    /** @return the bag size as a percentage of the training-set size */
    public int getBagSizePercent() {
        return this.m_BagSizePercent;
    }

    /** @param newBagSizePercent the bag size as a percentage of the training-set size */
    public void setBagSizePercent(int newBagSizePercent) {
        this.m_BagSizePercent = newBagSizePercent;
    }

    /** @return tip text for the represent-copies-using-weights property */
    public String representCopiesUsingWeightsTipText() {
        return "Whether to represent copies of instances using weights rather than explicitly.";
    }

    /** @param representUsingWeights whether copies are represented using weights */
    public void setRepresentCopiesUsingWeights(boolean representUsingWeights) {
        this.m_RepresentUsingWeights = representUsingWeights;
    }

    /** @return whether copies are represented using weights */
    public boolean getRepresentCopiesUsingWeights() {
        return this.m_RepresentUsingWeights;
    }

    /** @return tip text for the calc-out-of-bag property */
    public String calcOutOfBagTipText() {
        return "Whether the out-of-bag error is calculated.";
    }

    /** @param calcOutOfBag whether the out-of-bag error is computed during training */
    public void setCalcOutOfBag(boolean calcOutOfBag) {
        this.m_CalcOutOfBag = calcOutOfBag;
    }

    /** @return whether the out-of-bag error is computed during training */
    public boolean getCalcOutOfBag() {
        return this.m_CalcOutOfBag;
    }

    /** @return the out-of-bag error of the last build (0 when it was not computed) */
    public double measureOutOfBagError() {
        return this.m_OutOfBagError;
    }

    /** @return the names of the additional measures ({@code measureOutOfBagError}) */
    @Override
    public Enumeration<String> enumerateMeasures() {
        Vector<String> measures = new Vector<String>(1);
        measures.addElement("measureOutOfBagError");
        return measures.elements();
    }

    /**
     * Returns the value of a named additional measure.
     *
     * @param additionalMeasureName the measure name (case-insensitive)
     * @return the measure value
     * @throws IllegalArgumentException if the measure is not supported
     */
    @Override
    public double getMeasure(String additionalMeasureName) {
        if (additionalMeasureName.equalsIgnoreCase("measureOutOfBagError")) {
            return measureOutOfBagError();
        }
        throw new IllegalArgumentException(additionalMeasureName + " not supported (Bagging)");
    }

    /**
     * Builds the bag for one iteration.
     *
     * <p>The data is resampled using a generator seeded with {@code m_Seed + iteration}. When the
     * out-of-bag error is requested, in-bag membership is recorded in {@code m_inBag[iteration]}.
     * Otherwise, a bag smaller than 100% is obtained by shuffling the resample and keeping its
     * first {@code bagSize} instances.
     *
     * @param iteration index of the ensemble member being trained
     * @return the bag of training instances
     * @throws Exception if resampling fails
     */
    @Override
    protected synchronized Instances getTrainingSet(int iteration) throws Exception {
        int numInstances = this.m_data.numInstances();
        // long arithmetic: numInstances * percent can overflow int on large datasets
        int bagSize = (int) ((long) numInstances * this.m_BagSizePercent / 100);
        Random random = new Random(this.m_Seed + iteration);
        if (this.m_CalcOutOfBag) {
            this.m_inBag[iteration] = new boolean[numInstances];
            return this.m_data.resampleWithWeights(random, this.m_inBag[iteration], getRepresentCopiesUsingWeights());
        }
        Instances bag = this.m_data.resampleWithWeights(random, getRepresentCopiesUsingWeights());
        if (bagSize < numInstances) {
            // NOTE(review): with copies represented as weights the resample may hold fewer than
            // numInstances rows, so truncating to bagSize may fail or not shrink it - confirm.
            bag.randomize(random);
            bag = new Instances(bag, 0, bagSize);
        }
        return bag;
    }

    /**
     * Trains the ensemble and, if requested, estimates the out-of-bag error.
     *
     * @param data the training data (instances with a missing class are ignored)
     * @throws IllegalArgumentException if copies are to be represented by weights but the base
     *     learner is not a {@link WeightedInstancesHandler}, or if the out-of-bag error is requested
     *     with a bag size other than 100%
     * @throws Exception if the data is not supported or a base learner fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        getCapabilities().testWithFail(data);
        if (getRepresentCopiesUsingWeights() && !(this.m_Classifier instanceof WeightedInstancesHandler)) {
            throw new IllegalArgumentException("Cannot represent copies using weights when base learner in bagging does not implement WeightedInstancesHandler.");
        }
        this.m_data = new Instances(data);
        this.m_data.deleteWithMissingClass();
        try {
            super.buildClassifier(this.m_data);
            if (this.m_CalcOutOfBag && this.m_BagSizePercent != 100) {
                throw new IllegalArgumentException("Bag size needs to be 100% if out-of-bag error is to be calculated!");
            }
            this.m_random = new Random(this.m_Seed);
            // one row per ensemble member; each row is allocated in getTrainingSet
            this.m_inBag = this.m_CalcOutOfBag ? new boolean[this.m_Classifiers.length][] : null;
            if (this.m_Classifier instanceof Randomizable) {
                for (Classifier member : this.m_Classifiers) {
                    ((Randomizable) member).setSeed(this.m_random.nextInt());
                }
            }
            buildClassifiers();
            this.m_OutOfBagError = getCalcOutOfBag() ? computeOutOfBagError() : 0.0;
        } finally {
            // release the training data even when building fails
            this.m_data = null;
        }
    }

    /**
     * Computes the weighted out-of-bag error over {@code m_data}: the misclassification rate for
     * nominal classes, or the mean absolute error for numeric classes. Instances without any
     * out-of-bag prediction are skipped.
     *
     * @return the error, or 0 if no instance received an out-of-bag prediction
     * @throws Exception if a base classifier fails
     */
    private double computeOutOfBagError() throws Exception {
        boolean numeric = this.m_data.classAttribute().isNumeric();
        double totalWeight = 0.0;
        double errorSum = 0.0;
        for (int i = 0; i < this.m_data.numInstances(); i++) {
            Instance inst = this.m_data.instance(i);
            double vote = outOfBagVote(i, inst, numeric);
            if (Utils.isMissingValue(vote)) {
                continue;
            }
            totalWeight += inst.weight();
            if (numeric) {
                errorSum += StrictMath.abs(vote - inst.classValue()) * inst.weight();
            } else if (vote != inst.classValue()) {
                errorSum += inst.weight();
            }
        }
        // reset rather than keep a stale value from a previous build
        return totalWeight > 0.0 ? errorSum / totalWeight : 0.0;
    }

    /**
     * Combines the predictions of the members whose bag did not contain training instance
     * {@code index}.
     *
     * @return the mean numeric prediction or the index of the winning class, or a missing value
     *     if no such member produced a usable prediction
     * @throws Exception if a base classifier fails
     */
    private double outOfBagVote(int index, Instance inst, boolean numeric) throws Exception {
        double[] votes = new double[numeric ? 1 : this.m_data.numClasses()];
        int numPredictions = 0;
        for (int j = 0; j < this.m_Classifiers.length; j++) {
            if (this.m_inBag[j][index]) {
                continue;
            }
            if (numeric) {
                double prediction = this.m_Classifiers[j].classifyInstance(inst);
                if (!Utils.isMissingValue(prediction)) {
                    votes[0] += prediction;
                    numPredictions++;
                }
            } else {
                double[] dist = this.m_Classifiers[j].distributionForInstance(inst);
                for (int k = 0; k < dist.length; k++) {
                    votes[k] += dist[k];
                }
            }
        }
        if (numeric) {
            return numPredictions == 0 ? Utils.missingValue() : votes[0] / numPredictions;
        }
        return Utils.eq(Utils.sum(votes), 0.0) ? Utils.missingValue() : Utils.maxIndex(votes);
    }

    /**
     * Combines the members' predictions for one instance.
     *
     * <p>Nominal class: the members' distributions are summed and normalized (an all-zero sum is
     * returned as is). Numeric class: element 0 is the mean of the non-missing member predictions,
     * or a missing value if every member predicted missing.
     *
     * @param instance the instance to predict
     * @return the combined distribution / prediction
     * @throws Exception if a base classifier fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        boolean numeric = instance.classAttribute().isNumeric();
        double[] result = new double[instance.numClasses()];
        int numPredictions = 0;
        // iterate the actual ensemble: m_NumIterations may have been changed after building
        for (Classifier member : this.m_Classifiers) {
            if (numeric) {
                double prediction = member.classifyInstance(instance);
                if (!Utils.isMissingValue(prediction)) {
                    result[0] += prediction;
                    numPredictions++;
                }
            } else {
                double[] dist = member.distributionForInstance(instance);
                for (int k = 0; k < dist.length; k++) {
                    result[k] += dist[k];
                }
            }
        }
        if (numeric) {
            result[0] = numPredictions == 0 ? Utils.missingValue() : result[0] / numPredictions;
        } else if (!Utils.eq(Utils.sum(result), 0.0)) {
            Utils.normalize(result);
        }
        return result;
    }

    /**
     * Describes every member of the ensemble and, when computed, the out-of-bag error.
     *
     * @return a textual description of the model
     */
    @Override
    public String toString() {
        if (this.m_Classifiers == null) {
            return "Bagging: No model built yet.";
        }
        StringBuilder text = new StringBuilder();
        text.append("All the base classifiers: \n\n");
        for (Classifier member : this.m_Classifiers) {
            text.append(member.toString()).append("\n\n");
        }
        if (this.m_CalcOutOfBag) {
            text.append("Out of bag error: ").append(Utils.doubleToString(this.m_OutOfBagError, 4)).append("\n\n");
        }
        return text.toString();
    }

    /**
     * Throws unless the base learner can generate partitions.
     *
     * @throws Exception if the base learner is not a {@link PartitionGenerator}
     */
    private void ensureBaseIsPartitionGenerator() throws Exception {
        if (!(this.m_Classifier instanceof PartitionGenerator)) {
            throw new Exception("Classifier: " + getClassifierSpec() + " cannot generate a partition");
        }
    }

    /**
     * Builds the ensemble so that it can generate partitions.
     *
     * @param data the training data
     * @throws Exception if the base learner cannot generate partitions or training fails
     */
    @Override
    public void generatePartition(Instances data) throws Exception {
        ensureBaseIsPartitionGenerator();
        buildClassifier(data);
    }

    /**
     * Concatenates the membership values produced by every member, in ensemble order.
     *
     * @param instance the instance to compute membership values for
     * @return the concatenated membership values
     * @throws Exception if the base learner cannot generate partitions
     */
    @Override
    public double[] getMembershipValues(Instance instance) throws Exception {
        ensureBaseIsPartitionGenerator();
        List<double[]> parts = new ArrayList<double[]>(this.m_Classifiers.length);
        int total = 0;
        for (Classifier member : this.m_Classifiers) {
            double[] values = ((PartitionGenerator) member).getMembershipValues(instance);
            total += values.length;
            parts.add(values);
        }
        double[] result = new double[total];
        int offset = 0;
        for (double[] values : parts) {
            System.arraycopy(values, 0, result, offset, values.length);
            offset += values.length;
        }
        return result;
    }

    /**
     * Returns the total number of partition elements across all members.
     *
     * @return the summed element count
     * @throws Exception if the base learner cannot generate partitions
     */
    @Override
    public int numElements() throws Exception {
        ensureBaseIsPartitionGenerator();
        int total = 0;
        for (Classifier member : this.m_Classifiers) {
            total += ((PartitionGenerator) member).numElements();
        }
        return total;
    }

    /** @return the revision string */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10470 $");
    }

    /**
     * Command-line entry point.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        runClassifier(new Bagging(), argv);
    }

    /**
     * Appends another Bagging model's members to this one. The combined ensemble only takes effect
     * after {@link #finalizeAggregation()}.
     *
     * @param toAggregate the model whose members are added
     * @return this model
     * @throws Exception if the base learners' classes are incompatible
     */
    @Override
    public Bagging aggregate(Bagging toAggregate) throws Exception {
        if (!this.m_Classifier.getClass().isAssignableFrom(toAggregate.m_Classifier.getClass())) {
            throw new Exception("Can't aggregate because base classifiers differ");
        }
        if (this.m_classifiersCache == null) {
            this.m_classifiersCache = new ArrayList<Classifier>();
            this.m_classifiersCache.addAll(Arrays.asList(this.m_Classifiers));
        }
        this.m_classifiersCache.addAll(Arrays.asList(toAggregate.m_Classifiers));
        return this;
    }

    /**
     * Installs the aggregated members as the ensemble. This is a no-op if
     * {@link #aggregate(Bagging)} was never called.
     *
     * @throws Exception declared by the interface
     */
    @Override
    public void finalizeAggregation() throws Exception {
        if (this.m_classifiersCache == null) {
            // nothing was aggregated: the current ensemble is already final
            return;
        }
        // a zero-length template avoids a trailing null slot when the cache is empty
        this.m_Classifiers = this.m_classifiersCache.toArray(new Classifier[0]);
        this.m_NumIterations = this.m_Classifiers.length;
        this.m_classifiersCache = null;
    }
}
