#ifndef __WeightUpdate_h
#define __WeightUpdate_h
// NOTE(review): identifiers containing a double underscore are reserved for
// the implementation; the guard name is kept unchanged in case other files
// test for it -- consider renaming once that has been checked.

// NOTE(review): two older #include lines were commented out at this point;
// their header names are not recoverable from this copy of the file.
#include <vector>

// NOTE(review): a using-directive in a header leaks into every file that
// includes it. It is kept because existing includers may rely on it --
// confirm before removing.
using namespace std;

// NOTE(review): the element type of every weight/gradient vector below is
// assumed to be double (matching the double temperature, learning rate and
// damping factors) -- confirm against the implementation file.

/**
   This class is an annealing weight update interface.
   It keeps record of the original weights and previous weights
   and returns a new weight vector given a vector of gradients
   for the weights. Update is governed by the temperature.
*/
class tWeightUpdate {
public:
  /// construct from the original weights; the previous weights start out as
  /// a copy of the original weights and the temperature starts at 0.0
  tWeightUpdate(const vector<double>& origWeights)
    : _origWeights(origWeights),
      _prevWeights(origWeights),
      _temp(0.0)
  {}

  /// virtual destructor so derived updaters can be deleted through the base
  virtual ~tWeightUpdate() {}

  /// update the weights using gradients; implemented by each derived class.
  /// @param newtemp  temperature to use for this update
  /// @param gradient gradient vector for the weights
  /// @return reference to a weight vector (presumably the stored previous
  ///         weights after the update -- defined by the implementation)
  virtual vector<double> const& UpdateWeights(double newtemp,
                                              vector<double> const& gradient) = 0;

  /// return the last weight vector
  vector<double> const& CurrentWeights() const { return _prevWeights; }

  /// return original weight vector
  vector<double> const& OrigWeights() const { return _origWeights; }

  /// return the stored temperature
  double GetTemp() const { return _temp; }

  /// set the temperature (defined out of line, not in this header)
  void SetTemperature(double temp);

protected:
  /// get non const reference to prev weights, so that derived classes can
  /// write the updated weights in place
  vector<double>& PrevWeights() { return _prevWeights; }

private:
  vector<double> _origWeights;  ///< weights passed to the constructor
  vector<double> _prevWeights;  ///< most recent weights
  double _temp;                 ///< stored temperature
};

/**
   This class implements the interface of tWeightUpdate by
   random updating of weights centered around the original
   weights in proportion to the temperature.
*/
class tWeightUpdateRandom : public tWeightUpdate {
public:
  /// construct from the original weights
  tWeightUpdateRandom(const vector<double>& origWeights)
    : tWeightUpdate(origWeights)
  {}

  /// virtual destructor
  virtual ~tWeightUpdateRandom() {}

  /// update the weights ignoring gradients (defined out of line)
  virtual vector<double> const& UpdateWeights(double newtemp,
                                              vector<double> const& gradient);
};

/**
   This class implements the interface of tWeightUpdate by
   doing a gradient update that is governed by the temperature
   and damping factors for the original and previous weights.
   The learning rate determines the magnitude of update.
*/
class tWeightUpdateGradient : public tWeightUpdate {
public:
  /// construct from the original weights; the learning rate and both
  /// damping factors default to 1.0
  tWeightUpdateGradient(const vector<double>& origWeights)
    : tWeightUpdate(origWeights),
      _lRate(1.0),
      _origDamping(1.0),
      _prevDamping(1.0)
  {}

  /// virtual destructor
  virtual ~tWeightUpdateGradient() {}

  /// update the weights using gradients (defined out of line)
  virtual vector<double> const& UpdateWeights(double newtemp,
                                              vector<double> const& gradient);

  /// set the learning rate for multiplicative update
  void SetLearningRate(double lRate) { _lRate = lRate; }

  /// set damping factor with respect to original weights
  void SetOrigDamping(double origDamping) { _origDamping = origDamping; }

  /// set damping factor with respect to previous weights
  void SetPrevDamping(double prevDamping) { _prevDamping = prevDamping; }

private:
  double _lRate;        ///< learning rate
  double _origDamping;  ///< damping factor w.r.t. the original weights
  double _prevDamping;  ///< damping factor w.r.t. the previous weights
};

#endif