Hi,
I need to use the binary confusion matrix, but it only works for me with SVM.
When I use it with Naive Bayes or AdaBoost, I get an error when I try to set the predicted labels or the ground truth labels on this input:
input = qualityMetricSet.getInputDataCollection().getInput(quality_metric_set.confusionMatrix)
What am I doing wrong?
Mariana
Hello Mariana,
Here is an example of binary confusion matrix computation for the AdaBoost algorithm:
/*
 //  Content:
 //     Java example of AdaBoost quality metrics
 ////////////////////////////////////////////////////////////////////////////////
 */

package com.intel.daal.examples.quality_metrics;

import java.nio.DoubleBuffer;

import com.intel.daal.algorithms.adaboost.Model;
import com.intel.daal.algorithms.adaboost.prediction.*;
import com.intel.daal.algorithms.adaboost.training.*;
import com.intel.daal.algorithms.adaboost.quality_metric_set.*;
import com.intel.daal.algorithms.classifier.prediction.ModelInputId;
import com.intel.daal.algorithms.classifier.prediction.NumericTableInputId;
import com.intel.daal.algorithms.classifier.prediction.PredictionResult;
import com.intel.daal.algorithms.classifier.prediction.PredictionResultId;
import com.intel.daal.algorithms.classifier.quality_metric.binary_confusion_matrix.*;
import com.intel.daal.algorithms.classifier.training.InputId;
import com.intel.daal.algorithms.classifier.training.TrainingResultId;
import com.intel.daal.data_management.data.NumericTable;
import com.intel.daal.data_management.data.HomogenNumericTable;
import com.intel.daal.data_management.data.MergedNumericTable;
import com.intel.daal.data_management.data_source.DataSource;
import com.intel.daal.data_management.data_source.FileDataSource;
import com.intel.daal.examples.utils.Service;
import com.intel.daal.services.DaalContext;

class AdaBoostTwoClassQualityMetricSetBatch {
    /* Input data set parameters */
    private static final String trainDataset = "../data/batch/adaboost_train.csv";
    private static final String testDataset  = "../data/batch/adaboost_test.csv";

    private static final int nFeatures = 20;

    private static TrainingResult   trainingResult;
    private static PredictionResult predictionResult;
    private static ResultCollection qualityMetricSetResult;

    private static NumericTable groundTruthLabels;
    private static NumericTable predictedLabels;

    private static DaalContext context = new DaalContext();

    public static void main(String[] args) throws java.io.FileNotFoundException, java.io.IOException {
        trainModel();
        testModel();
        testModelQuality();
        printResults();

        context.dispose();
    }

    private static void trainModel() {
        /* Retrieve data from the input data sets */
        FileDataSource trainDataSource = new FileDataSource(context, trainDataset,
                DataSource.DictionaryCreationFlag.DoDictionaryFromContext,
                DataSource.NumericTableAllocationFlag.NotAllocateNumericTable);

        /* Create Numeric Tables for training data and labels */
        NumericTable trainData = new HomogenNumericTable(context, Double.class, nFeatures, 0,
                NumericTable.AllocationFlag.NotAllocate);
        NumericTable trainGroundTruth = new HomogenNumericTable(context, Double.class, 1, 0,
                NumericTable.AllocationFlag.NotAllocate);
        MergedNumericTable mergedData = new MergedNumericTable(context);
        mergedData.addNumericTable(trainData);
        mergedData.addNumericTable(trainGroundTruth);

        /* Retrieve the data from an input file */
        trainDataSource.loadDataBlock(mergedData);

        /* Create algorithm objects to train the AdaBoost model */
        TrainingBatch algorithm = new TrainingBatch(context, Double.class, TrainingMethod.defaultDense);

        /* Pass a training data set and dependent values to the algorithm */
        algorithm.input.set(InputId.data, trainData);
        algorithm.input.set(InputId.labels, trainGroundTruth);

        /* Train the AdaBoost model */
        trainingResult = algorithm.compute();
    }

    private static void testModel() {
        FileDataSource testDataSource = new FileDataSource(context, testDataset,
                DataSource.DictionaryCreationFlag.DoDictionaryFromContext,
                DataSource.NumericTableAllocationFlag.NotAllocateNumericTable);

        /* Create Numeric Tables for testing data and labels */
        NumericTable testData = new HomogenNumericTable(context, Double.class, nFeatures, 0,
                NumericTable.AllocationFlag.NotAllocate);
        groundTruthLabels = new HomogenNumericTable(context, Double.class, 1, 0,
                NumericTable.AllocationFlag.NotAllocate);
        MergedNumericTable mergedData = new MergedNumericTable(context);
        mergedData.addNumericTable(testData);
        mergedData.addNumericTable(groundTruthLabels);

        /* Retrieve the data from an input file */
        testDataSource.loadDataBlock(mergedData);

        /* Create algorithm objects for AdaBoost prediction with the default method */
        PredictionBatch algorithm = new PredictionBatch(context, Double.class, PredictionMethod.defaultDense);

        /* Pass a testing data set and the trained model to the algorithm */
        Model model = trainingResult.get(TrainingResultId.model);
        algorithm.input.set(NumericTableInputId.data, testData);
        algorithm.input.set(ModelInputId.model, model);

        /* Compute prediction results */
        predictionResult = algorithm.compute();
    }

    private static void testModelQuality() {
        /* Retrieve predicted labels */
        predictedLabels = predictionResult.get(PredictionResultId.prediction);

        /* Create a quality metric set object to compute quality metrics of the AdaBoost algorithm */
        QualityMetricSetBatch quality_metric_set = new QualityMetricSetBatch(context);

        BinaryConfusionMatrixInput input = quality_metric_set.getInputDataCollection()
                .getInput(QualityMetricId.confusionMatrix);

        input.set(BinaryConfusionMatrixInputId.predictedLabels, predictedLabels);
        input.set(BinaryConfusionMatrixInputId.groundTruthLabels, groundTruthLabels);

        /* Compute quality metrics */
        qualityMetricSetResult = quality_metric_set.compute();
    }

    private static void printResults() {
        /* Print the classification results */
        Service.printClassificationResult(groundTruthLabels, predictedLabels, "Ground truth", "Classification results",
                "AdaBoost classification results (first 20 observations):", 20);

        /* Print the quality metrics */
        BinaryConfusionMatrixResult qualityMetricResult = qualityMetricSetResult
                .getResult(QualityMetricId.confusionMatrix);
        NumericTable confusionMatrix = qualityMetricResult.get(BinaryConfusionMatrixResultId.confusionMatrix);
        NumericTable binaryMetrics = qualityMetricResult.get(BinaryConfusionMatrixResultId.binaryMetrics);

        Service.printNumericTable("Confusion matrix:", confusionMatrix);

        DoubleBuffer qualityMetricsData = DoubleBuffer
                .allocate((int) (binaryMetrics.getNumberOfColumns() * binaryMetrics.getNumberOfRows()));
        qualityMetricsData = binaryMetrics.getBlockOfRows(0, binaryMetrics.getNumberOfRows(), qualityMetricsData);

        System.out.println("Accuracy: " + qualityMetricsData.get(BinaryMetricId.accuracy.getValue()));
        System.out.println("Precision: " + qualityMetricsData.get(BinaryMetricId.precision.getValue()));
        System.out.println("Recall: " + qualityMetricsData.get(BinaryMetricId.recall.getValue()));
        System.out.println("F-score: " + qualityMetricsData.get(BinaryMetricId.fscore.getValue()));
        System.out.println("Specificity: " + qualityMetricsData.get(BinaryMetricId.specificity.getValue()));
        System.out.println("AUC: " + qualityMetricsData.get(BinaryMetricId.AUC.getValue()));
    }
}
It works fine and produces the following output:
Confusion matrix:
401.000 0.000
0.000 1599.000
Accuracy: 1.0
Precision: 1.0
Recall: 1.0
F-score: 1.0
Specificity: 1.0
AUC: 1.0
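To see where these numbers come from, the binary metrics can be recomputed by hand from the 2x2 confusion matrix. The plain-Java sketch below does not call Intel DAAL, and the class name BinaryMetricsSketch is hypothetical. It assumes the matrix layout [[tp, fn], [fp, tn]], an F-score with beta = 1, and AUC taken as the mean of recall and specificity (a single-threshold approximation); check the DAAL documentation for the library's exact definitions. With 401 true positives, 1599 true negatives and no misclassifications, every metric evaluates to 1.0, consistent with the output above.

```java
// Plain-Java sketch (no DAAL dependency) of standard binary classification metrics.
// ASSUMPTION: the 2x2 confusion matrix is laid out as [[tp, fn], [fp, tn]].
// ASSUMPTION: F-score uses beta = 1; AUC is approximated as (recall + specificity) / 2.
final class BinaryMetricsSketch {
    private BinaryMetricsSketch() {}

    /** Returns {accuracy, precision, recall, fscore, specificity, auc}. Zero denominators yield NaN. */
    static double[] compute(double tp, double fn, double fp, double tn) {
        double accuracy = (tp + tn) / (tp + fn + fp + tn);
        double precision = tp / (tp + fp);
        double recall = tp / (tp + fn);
        double fscore = 2 * precision * recall / (precision + recall);
        double specificity = tn / (tn + fp);
        double auc = (recall + specificity) / 2;
        return new double[] {accuracy, precision, recall, fscore, specificity, auc};
    }

    public static void main(String[] args) {
        // Counts taken from the confusion matrix printed above
        double[] metrics = compute(401, 0, 0, 1599);
        String[] names = {"Accuracy", "Precision", "Recall", "F-score", "Specificity", "AUC"};
        for (int i = 0; i < names.length; i++) {
            System.out.println(names[i] + ": " + metrics[i]);
        }
    }
}
```

The same function also shows how the metrics diverge once errors appear, for example compute(40, 10, 20, 30) gives an accuracy of 0.7 but a precision of about 0.667.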
Regarding the Naive Bayes classifier: it is a multi-class classifier, so the binary confusion matrix cannot be used with the Naive Bayes algorithm. Please use the multi-class confusion matrix instead. Please refer to the SVMMulticlassQualityMetricSetBatchExample.java example for details.
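A multi-class confusion matrix generalizes the 2x2 table to an nClasses x nClasses table of counts. The plain-Java sketch below only illustrates that idea; it is not the DAAL API, and MultiClassConfusionSketch and its methods are hypothetical names. It assumes integer labels 0..nClasses-1, with rows indexed by the ground truth label and columns by the predicted label, and computes overall accuracy as the share of counts on the diagonal. For the actual DAAL classes and metric definitions, follow the SVMMulticlassQualityMetricSetBatchExample.java example.

```java
// Plain-Java sketch (no DAAL dependency) of a multi-class confusion matrix.
// ASSUMPTION: labels are integers in 0..nClasses-1; rows = ground truth, columns = prediction.
final class MultiClassConfusionSketch {
    private MultiClassConfusionSketch() {}

    static long[][] confusionMatrix(int[] groundTruth, int[] predicted, int nClasses) {
        if (groundTruth.length != predicted.length) {
            throw new IllegalArgumentException("label arrays must have equal length");
        }
        long[][] matrix = new long[nClasses][nClasses];
        for (int i = 0; i < groundTruth.length; i++) {
            // Count one observation in the cell (actual class, predicted class)
            matrix[groundTruth[i]][predicted[i]]++;
        }
        return matrix;
    }

    // Overall accuracy: correctly classified observations (the diagonal) over all observations
    static double overallAccuracy(long[][] matrix) {
        long correct = 0;
        long total = 0;
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++) {
                total += matrix[i][j];
                if (i == j) {
                    correct += matrix[i][j];
                }
            }
        }
        return (double) correct / total;
    }

    public static void main(String[] args) {
        int[] groundTruth = {0, 1, 2, 2, 1, 0};
        int[] predicted   = {0, 2, 2, 2, 1, 0};
        long[][] matrix = confusionMatrix(groundTruth, predicted, 3);
        System.out.println(java.util.Arrays.deepToString(matrix));
        System.out.println("Accuracy: " + overallAccuracy(matrix));
    }
}
```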
Best regards,
Victoriya