degate  0.1.2
adaboost.hpp
Go to the documentation of this file.
00001 /*
00002 
00003   This code is from http://terpconnect.umd.edu/~xliu10/research/adaboost.html
00004 
00005   TODO: verify the license terms of this third-party code before redistribution.
00006 
00007  */
00008 
00009 #ifndef ADABOOST_HPP
00010 #define ADABOOST_HPP
00011 
00012 #include <vector>
00013 #include <string>
00014 #include <math.h>
00015 
00016 // NOTE: a using-directive at header scope pollutes every includer; prefer explicit std:: qualification.
00017 using namespace std;
00018 
/*
  Generic boosting framework: Classifier<T> is the abstract interface for a
  weak classifier over samples of type T, and adaboost() trains the weights
  of a strong classifier (a weighted vote) from a pool of weak classifiers.
*/
template <class T>
class Classifier
{
public:
        // Polymorphic base: deleting a weak classifier through a
        // Classifier<T>* must run the derived destructor.
        virtual ~Classifier() {}

        // Performs the actual recognition on one sample.
        // MUST be implemented by the weak classifier; usually T is the
        // feature vector. Expected to return +1 or -1.
        virtual int recognize(T&) = 0;

        // MUST be implemented by the weak classifier: simply return the name
        // of the weak classifier itself. It is recommended to use this to
        // keep track of the weak classifiers - useful when more than 30
        // weak classifiers are trained.
        virtual std::string get_name() const = 0;

        // The AdaBoost algorithm that trains the strong classifier from weak
        // classifiers.
        //   clsfrs   - the pool of weak classifiers
        //   data     - training samples (parallel to label)
        //   label    - ground truth per sample, +1 or -1
        //   maxround - maximum number of boosting iterations
        // Returns one weight (alpha) per weak classifier; the strong decision
        // is sign(sum_j alpha[j] * clsfrs[j]->recognize(x)).
        // Every weak classifier is evaluated on every sample up front, so the
        // boosting rounds themselves are fast.
        static std::vector<float> adaboost(std::vector<Classifier<T>*> clsfrs,
                                           std::vector<T*> data,
                                           std::vector<int> label,
                                           const int maxround = 80)
        {
                std::vector<float> alpha;

                // Degenerate input: nothing to train on.
                if (data.size() != label.size() || clsfrs.size() == 0 || label.size() == 0)
                        return alpha;

                alpha.resize(clsfrs.size());

                // Sample weights start uniform.
                std::vector<float> d(label.size(), 1.0f / float(label.size()));

                // Run every weak classifier on all the training data once.
                std::vector< std::vector<int> > rec(clsfrs.size());
                for (unsigned int j = 0; j < clsfrs.size(); j++)
                {
                        rec[j].resize(label.size());
                        for (unsigned int i = 0; i < label.size(); i++)
                                rec[j][i] = clsfrs[j]->recognize(*data[i]);
                }

                // Keeps log()/division finite when a weak classifier is perfect.
                const float eps = 1e-6f;

                // Run at most maxround iterations of boosting.
                for (int round = 0; round < maxround; round++)
                {
                        // Pick the weak classifier with the lowest weighted error.
                        float minerr = (float)label.size();
                        int best = 0;
                        for (unsigned int j = 0; j < clsfrs.size(); j++)
                        {
                                float err = 0;
                                for (unsigned int i = 0; i < label.size(); i++)
                                {
                                        if (rec[j][i] != label[i])
                                                err += d[i];
                                }
                                if (err < minerr)
                                {
                                        minerr = err;
                                        best = (int)j;
                                }
                        }

                        // No weak classifier beats chance: boosting cannot continue.
                        if (minerr >= 0.5f) break;

                        // Clamp the error away from zero so alpha stays finite
                        // (the original code produced +inf alpha and NaN sample
                        // weights when minerr == 0).
                        float clamped = (minerr < eps) ? eps : minerr;
                        float a = log((1.0f - clamped) / clamped) / 2;
                        alpha[best] += a;

                        // A perfect weak classifier leaves nothing to reweight.
                        if (minerr <= 0.0f) break;

                        // Reweight: boost misclassified samples, then
                        // renormalize so the weights sum to 1.
                        float z = 0;
                        std::vector<float> d1(label.size());
                        for (unsigned int i = 0; i < label.size(); i++)
                        {
                                d1[i] = d[i] * exp(-a * label[i] * rec[best][i]);
                                z += d1[i];
                        }
                        for (unsigned int i = 0; i < label.size(); i++)
                                d[i] = d1[i] / z;
                }
                return alpha;
        }
};
00099 
00100 //The linear combination of weak classifiers i.e. the strong classifier
00101 
00102 template <class T>
00103 class MultiClassifier :public Classifier<T>
00104 {
00105 private:
00106         vector<float> weights;
00107         vector<Classifier<T>*> clsfrs;
00108 public:
00109         float score;
00110         MultiClassifier(vector<float> w, vector<Classifier<T>*> c)
00111         {       
00112                 this->weights = w;
00113                 this->clsfrs = c;
00114         }
00115   std::string get_name() const { return "MultiClassifier"; }
00116     int recognize(T& obj)
00117         {
00118                 float res=0;
00119                 for (unsigned int i=0;i<weights.size();i++)
00120                   if(weights[i]> 0) res+=weights[i]*clsfrs[i]->recognize(obj);
00121                 score=res;
00122                 if (res>=0) 
00123                         return 1;
00124                 else
00125                         return -1;
00126         }
00127 };
00128 
00129 // the utility function that tests a (strong) classifier over all the test data
00130 
00131 template <class T>
00132 void testClassifier(Classifier<T>* cls, vector<T*> data, vector<int> label, float & fpos, float & fneg)
00133 {
00134         int pos = 0, neg = 0;
00135         fpos=fneg=0;
00136         for (int i=0;i<label.size();i++)
00137         {
00138                 int rec = cls->recognize(*data[i]);
00139                 if (label[i]==1)
00140                 {
00141                         pos++;
00142                         if (rec!=1)
00143                                 fneg=fneg+1;
00144                 }
00145                 if (label[i]==-1)
00146                 {
00147                         neg++;
00148                         if (rec!=-1)
00149                                 fpos=fpos+1;
00150                 }
00151         }
00152         fpos = fpos/neg;
00153         fneg = fneg/pos;
00154 }
00155 
00156 #endif