KSquare Utilities
KKMLL::CrossValidation Class Reference

A class that manages an n-fold cross validation. More...

#include <CrossValidation.h>

Public Member Functions

 CrossValidation (TrainingConfiguration2Ptr _config, FeatureVectorListPtr _examples, MLClassListPtr _mlClasses, kkint32 _numOfFolds, bool _featuresAreAlreadyNormalized, FileDescPtr _fileDesc, RunLog &_log, bool &_cancelFlag)
 
 ~CrossValidation ()
 
const VectorFloat & Accuracies () const
 
float Accuracy ()
 
float AccuracyMean () const
 
float AccuracyNorm ()
 
float AccuracyStdDev () const
 
double AvgPredProb () const
 
const ConfusionMatrix2Ptr ConfussionMatrix () const
 
kkint32 DuplicateTrainDataCount () const
 
const VectorFloat & FoldAccuracies () const
 
float FoldAccuracy (kkint32 foldNum) const
 
KKStr FoldAccuracysToStr () const
 
ConfusionMatrix2Ptr GiveMeOwnershipOfConfusionMatrix ()
 
void NumOfFolds (kkint32 _numOfFolds)
 
kkint32 NumOfSupportVectors () const
 
kkint32 NumSVs () const
 
void RunCrossValidation (RunLog &log)
 
void RunValidationOnly (FeatureVectorListPtr validationData, bool *classedCorrectly, RunLog &log)
 
const VectorFloat & SupportPoints () const
 
double SupportPointsMean () const
 
double SupportPointsStdDev () const
 
kkint32 SupportPointsTotal () const
 
double TestTimeMean () const
 
const VectorDouble & TestTimes () const
 
double TestTimeStdDev () const
 
double TestTimeTotal () const
 
kkint32 TotalNumSVs () const
 
double TrainTimeMean () const
 
const VectorDouble & TrainTimes () const
 
double TrainTimeStdDev () const
 
double TrainTimeTotal () const
 

Detailed Description

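A class that manages an n-fold cross validation. RunCrossValidation splits the examples into numOfFolds contiguous folds, in list order, with the last fold taking whatever examples remain. Each fold is used once as the test set while the remaining examples form the training set. RunValidationOnly instead trains on all examples and tests against a separately supplied validation list.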

Author
Kurt Kramer
History

Date        Programmer   Description
----------  -----------  --------------------------------------------
2004        Kurt Kramer  Original development.
2005-01-07  Kurt Kramer  Added the classedCorrectly parameter to
                         CrossValidate. If not null, it should point
                         to an array of bool with as many elements as
                         there are in the testImages list. Each element
                         indicates whether the corresponding element in
                         testImages was classified correctly.

Definition at line 82 of file CrossValidation.h.

Constructor & Destructor Documentation

CrossValidation::CrossValidation ( TrainingConfiguration2Ptr  _config,
FeatureVectorListPtr  _examples,
MLClassListPtr  _mlClasses,
kkint32  _numOfFolds,
bool  _featuresAreAlreadyNormalized,
FileDescPtr  _fileDesc,
RunLog &  _log,
bool &  _cancelFlag 
)

Definition at line 32 of file CrossValidation.cpp.

References KKMLL::TrainingConfiguration2::ExamplesPerClass(), KKMLL::FeatureVectorList::ExtractExamplesForClassList(), and KKMLL::TrainingConfiguration2::FvFactoryProducer().

Referenced by KKMLL::CrossValidationMxN::RunTrainAndTest(), and KKMLL::CrossValidationMxN::RunValidations().

40  :
41 
42  cancelFlag (_cancelFlag),
43  config (_config),
44  duplicateTrainDataCount (0),
45  featuresAreAlreadyNormalized (_featuresAreAlreadyNormalized),
46  fileDesc (_fileDesc),
47  foldAccuracies (),
48  foldCounts (),
49  fvProducerFactory (NULL),
50  confusionMatrix (NULL),
51  cmByNumOfConflicts (NULL),
52  examples (NULL),
53  mlClasses (_mlClasses),
54  imagesPerClass (0),
55  maxNumOfConflicts (0),
56  numOfFolds (_numOfFolds),
57  numSVs (0),
58  totalNumSVs (0),
59  numOfWinnersCounts (NULL),
60  numOfWinnersCorrects (NULL),
61  numOfWinnersOneOfTheWinners (NULL),
62  testTime (0.0),
63  trainingTime (0.0),
64 
65  accuracyMean (0.0f),
66  accuracyStdDev (0.0f),
67 
68  avgPredProb (0.0),
69  totalPredProb (0.0),
70 
71  supportPointsMean (0.0f),
72  supportPointsStdDev (0.0f),
73  supportPoints (),
74 
75  testTimeMean (0.0),
76  testTimeStdDev (0.0),
77  testTimes (),
78 
79  trainTimeMean (0.0),
80  trainTimeStdDev (0.0),
81  trainTimes (),
82 
83  weOwnConfusionMatrix (false)
84 
85 {
86  fvProducerFactory = config->FvFactoryProducer (_log);
87  examples = _examples->ExtractExamplesForClassList (mlClasses);
88  if (config)
89  imagesPerClass = config->ExamplesPerClass ();
90  else
91  imagesPerClass = -1;
92 }
FeatureVectorListPtr ExtractExamplesForClassList(MLClassListPtr classes)
FactoryFVProducerPtr FvFactoryProducer(RunLog &log) const
CrossValidation::~CrossValidation ( )

Definition at line 96 of file CrossValidation.cpp.

97 {
98  DeleteAllocatedMemory ();
99  delete examples; examples = NULL;
100 }

Member Function Documentation

const VectorFloat& KKMLL::CrossValidation::Accuracies ( ) const
inline

Definition at line 127 of file CrossValidation.h.

127 {return foldAccuracies;}
float CrossValidation::Accuracy ( )

Definition at line 456 of file CrossValidation.cpp.

References KKMLL::ConfusionMatrix2::Accuracy().

457 {
458  if (confusionMatrix)
459  return (float)confusionMatrix->Accuracy ();
460  else
461  return 0.0f;
462 } /* Accuracy */
float KKMLL::CrossValidation::AccuracyMean ( ) const
inline

Definition at line 128 of file CrossValidation.h.

128 {return accuracyMean;}
float CrossValidation::AccuracyNorm ( )

Definition at line 466 of file CrossValidation.cpp.

References KKMLL::ConfusionMatrix2::AccuracyNorm().

467 {
468  if (confusionMatrix)
469  return (float)confusionMatrix->AccuracyNorm ();
470  else
471  return 0.0f;
472 } /* Accuracy */
float KKMLL::CrossValidation::AccuracyStdDev ( ) const
inline

Definition at line 129 of file CrossValidation.h.

129 {return accuracyStdDev;}
double KKMLL::CrossValidation::AvgPredProb ( ) const
inline

Definition at line 131 of file CrossValidation.h.

131 {return avgPredProb;}
const ConfusionMatrix2Ptr KKMLL::CrossValidation::ConfussionMatrix ( ) const
inline

Definition at line 105 of file CrossValidation.h.

Referenced by KKMLL::CrossValidationMxN::RunTrainAndTest(), and KKMLL::CrossValidationMxN::RunValidations().

105 {return confusionMatrix;}
kkint32 KKMLL::CrossValidation::DuplicateTrainDataCount ( ) const
inline

Definition at line 109 of file CrossValidation.h.

109 {return duplicateTrainDataCount;}
const VectorFloat& KKMLL::CrossValidation::FoldAccuracies ( ) const
inline

Definition at line 115 of file CrossValidation.h.

115 {return foldAccuracies;}
float CrossValidation::FoldAccuracy ( kkint32  foldNum) const

Definition at line 495 of file CrossValidation.cpp.

496 {
497  if ((foldNum < 0) || (foldNum >= (kkint32)foldAccuracies.size ()))
498  {
499  return 0.0f;
500  }
501 
502  return foldAccuracies[foldNum];
503 } /* FoldAccuracy */
__int32 kkint32
Definition: KKBaseTypes.h:88
KKStr CrossValidation::FoldAccuracysToStr ( ) const

Definition at line 477 of file CrossValidation.cpp.

References KKB::KKStr::Concat(), and KKB::KKStr::KKStr().

478 {
479  KKStr foldAccuracyStr (9 * numOfFolds); // Pre Reserving enough space for all Accuracies.
480 
481  for (kkuint32 foldNum = 0; foldNum < foldAccuracies.size (); foldNum++)
482  {
483  if (foldNum > 0)
484  foldAccuracyStr << "\t";
485  foldAccuracyStr << StrFormatDouble (foldAccuracies[foldNum], "ZZ,ZZ0.00%");
486  }
487 
488  return foldAccuracyStr;
489 } /* FoldAccuracysToStr */
unsigned __int32 kkuint32
Definition: KKBaseTypes.h:89
KKStr StrFormatDouble(double val, const char *mask)
Definition: KKStr.cpp:4819
ConfusionMatrix2Ptr CrossValidation::GiveMeOwnershipOfConfusionMatrix ( )

Definition at line 508 of file CrossValidation.cpp.

509 {
510  weOwnConfusionMatrix = false;
511  return confusionMatrix;
512 }
void KKMLL::CrossValidation::NumOfFolds ( kkint32  _numOfFolds)
inline

Definition at line 112 of file CrossValidation.h.

112 {numOfFolds = _numOfFolds;}
kkint32 KKMLL::CrossValidation::NumOfSupportVectors ( ) const
inline

Definition at line 121 of file CrossValidation.h.

121 {return numSVs;}
kkint32 KKMLL::CrossValidation::NumSVs ( ) const
inline

Definition at line 122 of file CrossValidation.h.

122 {return numSVs;}
void CrossValidation::RunCrossValidation ( RunLog &  log)

Definition at line 166 of file CrossValidation.cpp.

References KKMLL::FeatureVectorList::ManufactureEmptyList(), and KKMLL::FeatureVectorList::PushOnBack().

Referenced by KKMLL::CrossValidationMxN::RunValidations().

167 {
168  log.Level (10) << "CrossValidation::RunCrossValidation numOfFolds[" << numOfFolds << "]" << endl;
169 
170  if (numOfFolds < 1)
171  {
172  log.Level (-1) << endl
173  << "CrossValidation::RunCrossValidation **** ERROR ****" << endl
174  << endl
175  << " Invalid numOfFolds[" << numOfFolds << "]." << endl
176  << endl;
177  return;
178  }
179 
180  DeleteAllocatedMemory ();
181  AllocateMemory ();
182 
183  kkint32 imageCount = examples->QueueSize ();
184  kkint32 numImagesPerFold = (imageCount + numOfFolds - 1) / numOfFolds;
185  kkint32 firstInGroup = 0;
186 
187  totalPredProb = 0.0;
188 
189 
190  kkint32 foldNum;
191 
192  for (foldNum = 0; foldNum < numOfFolds; foldNum++)
193  {
194  kkint32 lastInGroup;
195 
196  // If We are doing the last Fold Make sure that we are including all the examples
197  // that have not been tested.
198  if (foldNum == (numOfFolds - 1))
199  lastInGroup = imageCount;
200  else
201  lastInGroup = firstInGroup + numImagesPerFold - 1;
202 
203 
204  log.Level (20) << "Fold [" << (foldNum + 1) << "] of [" << numOfFolds << "]" << endl;
205 
206  FeatureVectorListPtr trainingExamples = examples->ManufactureEmptyList (true);
207  FeatureVectorListPtr testImages = examples->ManufactureEmptyList (true);
208 
209  log.Level (30) << "Fold Num[" << foldNum << "] "
210  << "FirstTestImage[" << firstInGroup << "] "
211  << "LastInGroup[" << lastInGroup << "]."
212  << endl;
213 
214  for (kkint32 x = 0; (x < imageCount) && (!cancelFlag); x++)
215  {
216  FeatureVectorPtr newImage = examples->IdxToPtr (x)->Duplicate ();
217  if ((x >= firstInGroup) && (x <= lastInGroup))
218  {
219  testImages->PushOnBack (newImage);
220  }
221  else
222  {
223  trainingExamples->PushOnBack (newImage);
224  }
225  }
226 
227  log.Level (20) << "Number Of Training Images : " << trainingExamples->QueueSize () << endl;
228  log.Level (20) << "Number Of Test Images : " << testImages->QueueSize () << endl;
229 
230  if (cancelFlag)
231  break;
232 
233  CrossValidate (testImages, trainingExamples, foldNum, NULL, log);
234 
235  delete trainingExamples; trainingExamples = NULL;
236  delete testImages; testImages = NULL;
237 
238  firstInGroup = firstInGroup + numImagesPerFold;
239  }
240 
241  if (!cancelFlag)
242  {
243  avgPredProb = totalPredProb / imageCount;
244 
245  CalcMeanAndStdDev (foldAccuracies, accuracyMean, accuracyStdDev);
246  CalcMeanAndStdDev (supportPoints, supportPointsMean, supportPointsStdDev);
247  CalcMeanAndStdDev (testTimes, testTimeMean, testTimeStdDev);
248  CalcMeanAndStdDev (trainTimes, trainTimeMean, trainTimeStdDev);
249  }
250 } /* RunCrossValidation */
HTMLReport &__cdecl endl(HTMLReport &htmlReport)
Definition: HTMLReport.cpp:240
void PushOnBack(FeatureVectorPtr image)
Overloading the PushOnBack function in KKQueue so we can monitor the Version and Sort Order...
__int32 kkint32
Definition: KKBaseTypes.h:88
EntryPtr IdxToPtr(kkuint32 idx) const
Definition: KKQueue.h:732
virtual FeatureVectorListPtr ManufactureEmptyList(bool _owner) const
Creates an instance of an empty FeatureVectorList.
RunLog & Level(kkint32 _level)
Definition: RunLog.cpp:220
virtual FeatureVectorPtr Duplicate() const
Container class for FeatureVector derived objects.
void CalcMeanAndStdDev(const vector< T > &v, T &mean, T &stdDev)
kkint32 QueueSize() const
Definition: KKQueue.h:313
Represents a Feature Vector of a single example, labeled or unlabeled.
Definition: FeatureVector.h:59
void CrossValidation::RunValidationOnly ( FeatureVectorListPtr  validationData,
bool *  classedCorrectly,
RunLog &  log 
)

Definition at line 256 of file CrossValidation.cpp.

References KKMLL::FeatureVectorList::DuplicateListAndContents().

Referenced by KKMLL::CrossValidationMxN::RunTrainAndTest().

260 {
261  log.Level (10) << "CrossValidation::RunValidationOnly" << endl;
262  DeleteAllocatedMemory ();
263  AllocateMemory ();
264 
265  totalPredProb = 0.0;
266 
267  // We need to get a duplicate copy of each image data because the trainer and classifier
268  // will normalize the data.
269  FeatureVectorListPtr trainingExamples = examples->DuplicateListAndContents ();
270  FeatureVectorListPtr testImages = validationData->DuplicateListAndContents ();
271 
272  CrossValidate (testImages, trainingExamples, 0, classedCorrectly, log);
273 
274  if (testImages->QueueSize () > 0)
275  avgPredProb = totalPredProb / testImages->QueueSize ();
276  else
277  avgPredProb = 0.0f;
278 
279  delete trainingExamples; trainingExamples = NULL;
280  delete testImages; testImages = NULL;
281 
282 
283  if (!cancelFlag)
284  {
285  CalcMeanAndStdDev (foldAccuracies, accuracyMean, accuracyStdDev);
286  CalcMeanAndStdDev (supportPoints, supportPointsMean, supportPointsStdDev);
287  CalcMeanAndStdDev (testTimes, testTimeMean, testTimeStdDev);
288  CalcMeanAndStdDev (trainTimes, trainTimeMean, trainTimeStdDev);
289  }
290 } /* RunValidationOnly */
HTMLReport &__cdecl endl(HTMLReport &htmlReport)
Definition: HTMLReport.cpp:240
virtual FeatureVectorListPtr DuplicateListAndContents() const
Creates a duplicate of the list and also duplicates its contents.
RunLog & Level(kkint32 _level)
Definition: RunLog.cpp:220
Container class for FeatureVector derived objects.
void CalcMeanAndStdDev(const vector< T > &v, T &mean, T &stdDev)
kkint32 QueueSize() const
Definition: KKQueue.h:313
const VectorFloat& KKMLL::CrossValidation::SupportPoints ( ) const
inline

Definition at line 133 of file CrossValidation.h.

133 {return supportPoints;}
double KKMLL::CrossValidation::SupportPointsMean ( ) const
inline

Definition at line 134 of file CrossValidation.h.

134 {return supportPointsMean;}
double KKMLL::CrossValidation::SupportPointsStdDev ( ) const
inline

Definition at line 135 of file CrossValidation.h.

135 {return supportPointsStdDev;}
kkint32 KKMLL::CrossValidation::SupportPointsTotal ( ) const
inline

Definition at line 125 of file CrossValidation.h.

125 {return numSVs;}
double KKMLL::CrossValidation::TestTimeMean ( ) const
inline

Definition at line 138 of file CrossValidation.h.

138 {return testTimeMean;}
const VectorDouble& KKMLL::CrossValidation::TestTimes ( ) const
inline

Definition at line 137 of file CrossValidation.h.

137 {return testTimes;}
double KKMLL::CrossValidation::TestTimeStdDev ( ) const
inline

Definition at line 139 of file CrossValidation.h.

139 {return testTimeStdDev;}
double KKMLL::CrossValidation::TestTimeTotal ( ) const
inline

Definition at line 140 of file CrossValidation.h.

140 {return testTime;}
kkint32 KKMLL::CrossValidation::TotalNumSVs ( ) const
inline

Definition at line 123 of file CrossValidation.h.

123 {return totalNumSVs;}
double KKMLL::CrossValidation::TrainTimeMean ( ) const
inline

Definition at line 143 of file CrossValidation.h.

143 {return trainTimeMean;}
const VectorDouble& KKMLL::CrossValidation::TrainTimes ( ) const
inline

Definition at line 142 of file CrossValidation.h.

142 {return trainTimes;}
double KKMLL::CrossValidation::TrainTimeStdDev ( ) const
inline

Definition at line 144 of file CrossValidation.h.

144 {return trainTimeStdDev;}
double KKMLL::CrossValidation::TrainTimeTotal ( ) const
inline

Definition at line 145 of file CrossValidation.h.

145 {return trainingTime;}

The documentation for this class was generated from the following files: