testmnist.cpp

File testmnist.cpp provides a slightly more complicated example of using the CvConvNet class. The program runs over the entire MNIST test dataset and calculates the error rate.

00001 /*****************************************************************************
00002  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. By
00003 downloading, copying, installing or using the software you agree to this
00004 license. If you do not agree to this license, do not download, install, copy or
00005 use the software.
00006 
00007 Contributors License Agreement
00008 
00009 Copyright© 2007, Akhmed Umyarov. All rights reserved.
00010 
00011 Redistribution and use in source and binary forms, with or without modification,
00012 are permitted provided that the following conditions are met:
00013 - Redistributions of source code must retain the above copyright notice, this
00014 list of conditions and the following disclaimer.
00015 - Redistributions in binary form must reproduce the above copyright notice, this
00016 list of conditions and the following disclaimer in the documentation and/or
00017 other materials provided with the distribution.
00018 - The name of Contributor may not be used to endorse or promote products derived
00019 from this software without specific prior written permission.
00020 
00021 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
00022 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00023 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00024 DISCLAIMED. IN NO EVENT SHALL THE CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00025 INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00026 BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
00027 DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00028 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
00029 OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
00030 ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00031 All information provided related to future Intel products and plans is
00032 preliminary and subject to change at any time, without notice.
00033 *****************************************************************************/
00034 
00041 #include "cvconvnet.h"
00042 #include <iostream>
00043 #include <fstream>
00044 #include <sstream>
00045 #include <exception>
00046 
00047 using namespace std;
00048 
00054 int main(int argc, char *argv[])
00055 {
00056         if (argc <=1 )
00057         {
00058                 cerr << "Usage: " << endl << "\ttestmnist <network.xml>" << endl;
00059                 return 1;
00060         }
00061 
00062         // Create empty net object
00063         CvConvNet net;
00064 
00065         // Load mnist.xml file into string
00066         ifstream ifs( argv[1] );
00067         string xml ( (istreambuf_iterator<char> (ifs)) , istreambuf_iterator<char>() );
00068         
00069         // Create network from XML string
00070         if ( !net.fromString(xml) )
00071         {
00072                 cerr << "*** ERROR: Can't load net from XML string" << endl;
00073                 return 1;
00074         }
00075 
00076         // Represent MNIST datafiles as C++ file streams f1 and f2 respectively
00077         ifstream f1("t10k-images-idx3-ubyte",ios::in | ios::binary); // image data
00078         ifstream f2("t10k-labels-idx1-ubyte",ios::in | ios::binary); // label data
00079         
00080         if (!f1.is_open() || !f2.is_open())
00081         {
00082                 cerr << "ERROR: Can't open MNIST files. Please locate them in current directory" << endl;
00083                 return 1;
00084         }
00085         // Create buffers for image data and correct labels
00086         const int BUF_SIZE = 2048;
00087         char *buffer = new char[BUF_SIZE];
00088         char *label = new char[2];
00089 
00090         // Block for catching file exceptions
00091         try
00092         {
00093                 // Read headers
00094                 f1.read(buffer,16);
00095                 f2.read(buffer,8);
00096         
00097                 // Here is our info
00098                 int imgno = 10000; // 10'000 images in file
00099                 int imgheight = 28; // image size
00100                 int imgwidth = 28;
00101                 int imgpadx = 2; // Pad images by 2 black pixels, so
00102                 int imgpady = 2; // the image becomes 32x32
00103                 int imgpaddedheight = imgheight+2*imgpady; // padded image size
00104                 int imgpaddedwidth = imgwidth+2*imgpadx;
00105                 
00106                 // Prepare image structures
00107                 IplImage *img32 = cvCreateImageHeader( cvSize(imgpaddedheight,imgpaddedwidth), IPL_DEPTH_8U, 1 );
00108         
00109                 // imageData now points to our buffer
00110                 img32->imageData = buffer;
00111         
00112                 // Clean the buffer
00113                 memset(buffer,0,BUF_SIZE);
00114         
00115                 // Initialize error counter
00116                 int errors = 0;
00117         
00118                 // Now cycle over all images in MNIST test dataset
00119                 for (int i=0; i<imgno; i++)
00120                 {
00121                         // Load the image from file stream into img32
00122                         // (remember img32->imgData points to our buffer)
00123                         for (int k=0; k<imgheight; k++)
00124                         {
00125                                 // Image in file is stored as 28x28, so we need to pad it to 32x32
00126                                 // So we read the image row-by-row with proper padding adjustments
00127                                 f1.read(&buffer[imgpadx+(imgpaddedwidth)*(k+2)],imgwidth);
00128                         }
00129         
00130                         // Propagate the matrix through network and get the result
00131                         int pos = (int) net.fprop(img32);
00132                         
00133                         // Now read the correct label from label file stream
00134                         f2.read(label,1);
00135         
00136                         // Check if our prediction is correct
00137                         if ( label[0]!=pos ) errors++;
00138                 }
00139                 
00140                 // Print the error rate
00141                 cout << "Error rate: " << (double)100.0*errors/imgno << "%" << endl;
00142                 
00143         } catch (exception &e)
00144         {
00145                 cerr << "Exception: " << e.what() << endl;
00146         }
00147 
00148 
00149         // Don't forget to free the memory
00150         delete[] label;
00151         delete[] buffer;
00152 
00153         // That's it!
00154         return 0;
00155 }

Generated on Fri Aug 3 16:17:27 2007 for ConvNet by  doxygen 1.5.0