KSquare Utilities
CrossValidation.h
Go to the documentation of this file.
1 #ifndef _CROSSVALIDATION_
2 #define _CROSSVALIDATION_
3 
4 /**
5  @class KKMLL::CrossValidation
6  @brief A class that is meant to manage an n-Fold Cross Validation.
7  @author Kurt Kramer
8  @details
9  @code
10  *********************************************************************
11  * CrossValidation *
12  * *
13  * *
14  * *
15  *-------------------------------------------------------------------*
16  * History *
17  * *
18  * Date Programmer Description *
19  * ---------- ----------- -----------------------------------------*
20  * 2004 Kurt Kramer Original Development. *
21  * *
22  * *
23  * 2005-01-07 Kurt Kramer Added classedCorrectly parameter to *
24  * CrossValidate. If not null it should *
25  * point to an array of bool that has as *
26  * many elements as there are in the *
27  * testImages list. Each element represents *
28  * whether the corresponding element in *
29  * testImages was classified correctly. *
30  *********************************************************************
31  @endcode
32  */
33 
34 
35 #include "KKBaseTypes.h"
36 #include "RunLog.h"
37 
38 
39 namespace KKMLL
40 {
41  #ifndef _FeatureVector_Defined_
42  class FeatureVectorList;
43  typedef FeatureVectorList* FeatureVectorListPtr;
44  #endif
45 
46 
47  #if !defined (_MLCLASS_)
48  class MLClass;
49  typedef MLClass* MLClassPtr;
50  class MLClassList;
51  typedef MLClassList* MLClassListPtr;
52  #endif
53 
54 
55 
56  #ifndef _ConfussionMatrix2_
57  class ConfusionMatrix2;
58  typedef ConfusionMatrix2* ConfusionMatrix2Ptr;
59  #endif
60 
61 
62 
63  #if !defined(_FactoryFVProducer_Defined_)
64  class FactoryFVProducer;
65  typedef FactoryFVProducer* FactoryFVProducerPtr;
66  #endif
67 
68 
69 
70  #ifndef _TrainingConfiguration2_Defined_
72  typedef TrainingConfiguration2* TrainingConfiguration2Ptr;
73  #endif
74 
75 
76  #ifndef _FILEDESC_
77  class FileDesc;
78  typedef FileDesc* FileDescPtr;
79  #endif
80 
81 
83  {
84  public:
85  CrossValidation (TrainingConfiguration2Ptr _config,
86  FeatureVectorListPtr _examples,
87  MLClassListPtr _mlClasses,
88  kkint32 _numOfFolds,
89  bool _featuresAreAlreadyNormalized,
90  FileDescPtr _fileDesc,
91  RunLog& _log,
92  bool& _cancelFlag
93  );
94 
95  ~CrossValidation ();
96 
97  void RunCrossValidation (RunLog& log);
98 
99  void RunValidationOnly (FeatureVectorListPtr validationData,
100  bool* classedCorrectly,
101  RunLog& log
102  );
103 
104  const
105  ConfusionMatrix2Ptr ConfussionMatrix () const {return confusionMatrix;}
106 
107  float Accuracy ();
108  float AccuracyNorm ();
109  kkint32 DuplicateTrainDataCount () const {return duplicateTrainDataCount;}
110  float FoldAccuracy (kkint32 foldNum) const;
111 
112  void NumOfFolds (kkint32 _numOfFolds) {numOfFolds = _numOfFolds;}
113 
114  const
115  VectorFloat& FoldAccuracies () const {return foldAccuracies;}
116 
117  KKStr FoldAccuracysToStr () const;
118 
119  ConfusionMatrix2Ptr GiveMeOwnershipOfConfusionMatrix ();
120 
121  kkint32 NumOfSupportVectors () const {return numSVs;}
122  kkint32 NumSVs () const {return numSVs;}
123  kkint32 TotalNumSVs () const {return totalNumSVs;}
124 
125  kkint32 SupportPointsTotal () const {return numSVs;}
126 
127  const VectorFloat& Accuracies () const {return foldAccuracies;}
128  float AccuracyMean () const {return accuracyMean;}
129  float AccuracyStdDev () const {return accuracyStdDev;}
130 
131  double AvgPredProb () const {return avgPredProb;}
132 
133  const VectorFloat& SupportPoints () const {return supportPoints;}
134  double SupportPointsMean () const {return supportPointsMean;}
135  double SupportPointsStdDev () const {return supportPointsStdDev;}
136 
137  const VectorDouble& TestTimes () const {return testTimes;}
138  double TestTimeMean () const {return testTimeMean;}
139  double TestTimeStdDev () const {return testTimeStdDev;}
140  double TestTimeTotal () const {return testTime;}
141 
142  const VectorDouble& TrainTimes () const {return trainTimes;}
143  double TrainTimeMean () const {return trainTimeMean;}
144  double TrainTimeStdDev () const {return trainTimeStdDev;}
145  double TrainTimeTotal () const {return trainingTime;}
146 
147  private:
148  void AllocateMemory ();
149 
150  void CrossValidate (FeatureVectorListPtr testImages,
151  FeatureVectorListPtr trainingExamples,
152  kkint32 foldNum,
153  bool* classedCorrectly,
154  RunLog& log
155  );
156 
157  void DeleteAllocatedMemory ();
158 
159  //void DistributesImagesRandomlyFromEachWithInFolds ();
160 
161 
162  bool cancelFlag;
163  TrainingConfiguration2Ptr config;
164  kkint32 duplicateTrainDataCount;
165  FactoryFVProducerPtr fvProducerFactory;
166  bool featuresAreAlreadyNormalized;
167  FileDescPtr fileDesc;
168  VectorFloat foldAccuracies;
169  VectorInt foldCounts;
170  ConfusionMatrix2Ptr confusionMatrix;
171  ConfusionMatrix2Ptr* cmByNumOfConflicts;
172  FeatureVectorListPtr examples;
173  MLClassListPtr mlClasses;
174  kkint32 imagesPerClass;
175  kkint32 maxNumOfConflicts; /**< Indicates the number of confusion matrices created in the table 'cmByNumOfConflicts'. */
176  kkint32 numOfFolds;
177 
178  kkint32 numSVs; /**< Total Support Vectors Detected. */
179 
180  kkint32 totalNumSVs; /**< Unlike 'numSVs', this reflects all the Support Vectors
181  * that are created in a Multi Class SVM. That is, if a given example is used in three
182  * different binary classifiers, it will be counted three times.
183  */
184 
185  kkint32* numOfWinnersCounts;
186  kkint32* numOfWinnersCorrects;
187  kkint32* numOfWinnersOneOfTheWinners;
188 
189  double testTime;
190  double trainingTime;
191 
192  float accuracyMean;
193  float accuracyStdDev;
194 
195  double avgPredProb;
196  double totalPredProb;
197 
198  float supportPointsMean;
199  float supportPointsStdDev;
200  VectorFloat supportPoints;
201 
202  double testTimeMean;
203  double testTimeStdDev;
204  VectorDouble testTimes;
205 
206  double trainTimeMean;
207  double trainTimeStdDev;
208  VectorDouble trainTimes;
209 
210  bool weOwnConfusionMatrix;
211  };
212 
214 
215 } /* namespace KKMLL */
216 
217 #endif
ConfusionMatrix2Ptr GiveMeOwnershipOfConfusionMatrix()
Provides a detailed description of the attributes of a dataset.
Definition: FileDesc.h:72
double AvgPredProb() const
double SupportPointsStdDev() const
__int32 kkint32
Definition: KKBaseTypes.h:88
double TrainTimeStdDev() const
kkint32 SupportPointsTotal() const
const VectorDouble & TrainTimes() const
const VectorFloat & SupportPoints() const
double TrainTimeMean() const
const ConfusionMatrix2Ptr ConfussionMatrix() const
std::vector< int > VectorInt
Definition: KKBaseTypes.h:138
Represents a "Class" in the Machine Learning Sense.
Definition: MLClass.h:52
float AccuracyStdDev() const
A class that is meant to manage an n-Fold Cross Validation.
void RunCrossValidation(RunLog &log)
kkint32 TotalNumSVs() const
kkint32 NumSVs() const
float AccuracyMean() const
const VectorDouble & TestTimes() const
double TestTimeStdDev() const
kkint32 DuplicateTrainDataCount() const
Container class for FeatureVector derived objects.
const VectorFloat & FoldAccuracies() const
double SupportPointsMean() const
double TrainTimeTotal() const
KKStr FoldAccuracysToStr() const
std::vector< float > VectorFloat
Definition: KKBaseTypes.h:149
static KKStr Concat(const std::vector< std::string > &values)
Concatenates the list of 'std::string' strings.
Definition: KKStr.cpp:1082
void NumOfFolds(kkint32 _numOfFolds)
const VectorFloat & Accuracies() const
double TestTimeMean() const
CrossValidation * CrossValidationPtr
double TestTimeTotal() const
float FoldAccuracy(kkint32 foldNum) const
Used for logging messages.
Definition: RunLog.h:49
void EncodeProblem(const struct svm_paramater &param, struct svm_problem &prob_in, struct svm_problem &prob_out)
Responsible for creating a FeatureVectorProducer instance.
Maintains a list of MLClass instances.
Definition: MLClass.h:233
CrossValidation(TrainingConfiguration2Ptr _config, FeatureVectorListPtr _examples, MLClassListPtr _mlClasses, kkint32 _numOfFolds, bool _featuresAreAlreadyNormalized, FileDescPtr _fileDesc, RunLog &_log, bool &_cancelFlag)
kkint32 NumOfSupportVectors() const
A confusion matrix object that is used to record the results from a CrossValidation. See also: CrossValidation.
void RunValidationOnly(FeatureVectorListPtr validationData, bool *classedCorrectly, RunLog &log)
std::vector< double > VectorDouble
Vector of doubles.
Definition: KKBaseTypes.h:148