#include "simulateJumps.h"
#include "talRandom.h"
#include "someUtil.h"
#include <algorithm>

// NOTE(review): this file reached review with its line breaks removed and with
// every "<...>" span stripped (template arguments, the <algorithm> include, and
// two longer stretches of code). The stripped spans were reconstructed from the
// surrounding code; each reconstructed stretch is marked "NOTE(review)" below
// and must be confirmed against the class declaration in simulateJumps.h /
// simulateJumpsAbstract.h before merging.

// Forwards the tree, the stochastic process and the alphabet size to the
// abstract base class; no additional state is set up here.
simulateJumps::simulateJumps(const tree& inTree, const stochasticProcess& sp, const int alphabetSize)
: simulateJumpsAbstract(inTree,sp,alphabetSize)
{
}

simulateJumps::~simulateJumps()
{
}

// Prepares everything the simulation needs:
//  - _waitingTimeParams[i] = -dPij_dt(i,i,0)
//  - _jumpProbs[i][j]      = dPij_dt(i,j,0) / _waitingTimeParams[i]  (0 on the diagonal)
//  - _orderNodesVec        = all nodes of the tree, sorted with compareDist
//  - zeroed per-node accumulators (_nodes2JumpsExp, _nodes2JumpsProb, _totalTerminals)
// Reports an error if a row of _jumpProbs does not sum to 1.
void simulateJumps::init()
{
	//init the vector of waiting times.
	_waitingTimeParams.clear();
	_waitingTimeParams.resize(_alphabetSize);
	for (int i = 0; i < _alphabetSize; ++i)
	{
		_waitingTimeParams[i] = -_sp.dPij_dt(i, i, 0.0);
	}

	//init _jumpProbs.
	//_jumpProbs[i][j] = Q[i][j] / -Q[i][i]
	_jumpProbs.clear();
	_jumpProbs.resize(_alphabetSize);
	for (int i = 0; i < _alphabetSize; ++i)
	{
		MDOUBLE sum = 0.0;
		_jumpProbs[i].resize(_alphabetSize);
		for (int j = 0; j < _alphabetSize; ++j)
		{
			if (i == j)
				_jumpProbs[i][j] = 0.0;
			else
			{
				_jumpProbs[i][j] = _sp.dPij_dt(i, j, 0.0) / _waitingTimeParams[i];
			}
			sum += _jumpProbs[i][j];
		}
		if (! DEQUAL(sum, 1.0)){
			string err = "error in simulateJumps::init(): sum probabilities is not 1 and equal to ";
			err+=double2string(sum);
			errorMsg::reportError(err);
		}
	}

	//init _orderNodesVec: a vector in which the branch lengths are ordered in ascending order
	_tree.getAllNodes(_orderNodesVec, _tree.getRoot());
	sort(_orderNodesVec.begin(), _orderNodesVec.end(), simulateJumpsAbstract::compareDist);

	_nodes2JumpsExp.clear();
	_nodes2JumpsProb.clear();
	// Cleared like the two maps above so a repeated init() cannot keep stale node entries.
	_totalTerminals.clear();

	const int combinedSize = getCombinedAlphabetSize();
	VVdouble zeroMatrix(combinedSize);
	for (int i = 0; i < combinedSize; ++i)
		zeroMatrix[i].resize(combinedSize, 0.0);
	Vdouble zeroVector(combinedSize,0.0);
	for (size_t n = 0; n < _orderNodesVec.size(); ++n)
	{
		string nodeName = _orderNodesVec[n]->name();
		_nodes2JumpsExp[nodeName] = zeroMatrix;
		_nodes2JumpsProb[nodeName] = zeroMatrix;
		// NOTE(review): the original text is cut off right after "for (j=0; j".
		// Reconstructed as a zero-initialisation of the per-node terminal counters,
		// which runOneIter() increments; one assignment gives the same final state
		// as repeating it inside a loop over j.
		_totalTerminals[nodeName] = zeroVector;
	}
}


// Simulates one trajectory that starts in startState and runs until the
// accumulated time passes the largest dis2father() in _orderNodesVec.
// In each step:
//  1. a waiting time is drawn with talRandom::rand_exp;
//  2. every branch not longer than the next jump time is updated with the jumps
//     recorded so far (such branches are not affected by the upcoming jump);
//  3. the next state is drawn and the jump is recorded.
//
// NOTE(review): the function header and the initialiser of maxTime were lost
// from the source text. The name runOneIter, the parameter list (int startState)
// and "_orderNodesVec[size()-1]" are reconstructed - confirm against the header.
void simulateJumps::runOneIter(int startState)
{
	// Nothing to update without nodes; also avoids indexing size()-1 on an empty vector.
	if (_orderNodesVec.empty())
		return;

	const int numNodes = static_cast<int>(_orderNodesVec.size());
	// _orderNodesVec is sorted in ascending order in init(), so the last entry is the longest branch.
	MDOUBLE maxTime = _orderNodesVec[numNodes - 1]->dis2father();
	MDOUBLE totalTimeTillJump = 0.0;
	int jumpsNum = 0; // counts the jumps drawn; not read anywhere in the visible code
	int curState = startState;
	int smallestBranchNotUpdatedSofar = 0;
	vector<pair<int,int> > jumpsSoFar(0);
	while (totalTimeTillJump < maxTime)
	{
		MDOUBLE avgWaitingTime = 1 / _waitingTimeParams[curState];
		MDOUBLE nextJumpTime = totalTimeTillJump + talRandom::rand_exp(avgWaitingTime);
		//go over all branches that "finished" their simulation (shorter than nextJumpTime) and update with their _nodes2JumpsExp
		//with the jumps that occurred between the terminal Ids: startState-->curState
		for (int b = smallestBranchNotUpdatedSofar; b < numNodes; ++b)
		{
			if (_orderNodesVec[b]->dis2father() > nextJumpTime)
			{
				// This branch is still "running": resume from it after the next jump.
				smallestBranchNotUpdatedSofar = b;
				break;
			}
			string nodeName = _orderNodesVec[b]->name();
			//update all the jumps that occurred along the branch
			int terminalState = getCombinedState(startState, curState);
			_totalTerminals[nodeName][terminalState]++;
			//update all longer branches with all jumps that occurred till now
			// jumpsSoFarBool marks which jump types occurred at least once, so the
			// probability accumulator is incremented once per type, not once per jump.
			vector<bool> jumpsSoFarBool(getCombinedAlphabetSize(),false);
			for (size_t j = 0; j < jumpsSoFar.size(); ++j)
			{
				int combinedJumpState = getCombinedState(jumpsSoFar[j].first, jumpsSoFar[j].second);
				jumpsSoFarBool[combinedJumpState]=true;
				_nodes2JumpsExp[nodeName][terminalState][combinedJumpState] += 1;
			}
			// NOTE(review): everything from the condition of this loop up to the
			// "pair<int,int>" in the push_back below was stripped from the source
			// text and is reconstructed - confirm, in particular the
			// giveRandomState(...) call and its argument order.
			for (size_t combined = 0; combined < jumpsSoFarBool.size(); ++combined)
			{
				if (jumpsSoFarBool[combined])
					_nodes2JumpsProb[nodeName][terminalState][combined] += 1;
			}
		}
		totalTimeTillJump = nextJumpTime;
		int nextState = giveRandomState(_alphabetSize, curState, _jumpProbs);
		jumpsSoFar.push_back(pair<int,int>(curState, nextState));
		curState = nextState;
		++jumpsNum;
	}
}


// Converts the raw counts accumulated by the simulation into per-terminal-state
// averages: every entry of _nodes2JumpsExp and _nodes2JumpsProb is divided by the
// number of times its terminal state was reached (_totalTerminals).
// For a terminal state that was never reached, both entries must be zero
// (otherwise an error is reported); they stay zero, except for the entry where
// the jump equals the terminal pair and start != end, which is set to 1.
//
// NOTE(review): the map value types in the iterator declarations (VVdouble /
// Vdouble) were stripped from the source text and are inferred from how the
// maps are filled in init() - confirm against the header.
void simulateJumps::computeExpectationsAndPosterior(){
	const int combinedSize = getCombinedAlphabetSize();
	//scale _nodes2JumpsExp so it will represent expectations
	map<string, VVdouble>::iterator iterExp = _nodes2JumpsExp.begin();
	for (; iterExp != _nodes2JumpsExp.end(); ++iterExp)
	{
		const string& nodeName = iterExp->first;
		// The lookups depend only on the node, so they are done once per node
		// instead of once per (termState, jumpState) pair.
		map<string, Vdouble>::iterator iterTerm = _totalTerminals.find(nodeName);
		map<string, VVdouble>::iterator iterProb = _nodes2JumpsProb.find(nodeName);
		if ((iterTerm==_totalTerminals.end()) || (iterProb==_nodes2JumpsProb.end()))
		{
			errorMsg::reportError("error in simulateJumps::runSimulation, unknown reason: cannot find nodeName in map");
		}
		for (int termState = 0; termState < combinedSize; ++termState)
		{
			for (int jumpState = 0; jumpState < combinedSize; ++jumpState)
			{
				if ((iterTerm->second[termState]==0)){ //never reached these terminal states
					if ((iterExp->second[termState][jumpState]==0) && (iterProb->second[termState][jumpState]==0)){
						if( termState == jumpState && (getStartId(termState)!=getEndId(termState) ) ){
							(iterExp->second[termState][jumpState]) = 1; // E.g - given start=0 end=1 there was at least one 0->1 jump
							(iterProb->second[termState][jumpState]) = 1; // E.g - given start=0 end=1 there was at least one 0->1 jump
						}
						continue;//leave the value of _nodes2JumpsExp and _nodes2JumpsProb as zero (or one)
					}
					else {
						errorMsg::reportError("error in simulateJumps::runSimulation, 0 times reached termState but non-zero for jumpCount");
					}
				}
				(iterExp->second[termState][jumpState]) /= iterTerm->second[termState];
				(iterProb->second[termState][jumpState]) /= iterTerm->second[termState];
			}
		}
	}
}


// Returns the _nodes2JumpsExp entry of node nodeName for the terminal pair
// (terminalStart, terminalEnd) and the jump fromId -> toId.
// Reports an error if nodeName is not a key of _nodes2JumpsExp.
// NOTE(review): the dereference after the error report presumes that
// errorMsg::reportError does not return - confirm.
MDOUBLE simulateJumps::getExpectation(const string& nodeName, int terminalStart, int terminalEnd, int fromId, int toId)
{
	map<string, VVdouble>::iterator pos;
	if ((pos = _nodes2JumpsExp.find(nodeName)) == _nodes2JumpsExp.end())
	{
		string err="error in simulateJumps::getExpectation: cannot find node "+nodeName;
		errorMsg::reportError(err);
	}
	int combinedTerminalState = getCombinedState(terminalStart, terminalEnd);
	int combinedJumpState = getCombinedState(fromId, toId);
	return (pos->second[combinedTerminalState][combinedJumpState]);
}


// Returns the _nodes2JumpsProb entry of node nodeName for the terminal pair
// (terminalStart, terminalEnd) and the jump fromId -> toId.
// Reports an error if nodeName is not a key of _nodes2JumpsProb.
// NOTE(review): the dereference after the error report presumes that
// errorMsg::reportError does not return - confirm.
MDOUBLE simulateJumps::getProb(const string& nodeName, int terminalStart, int terminalEnd, int fromId, int toId){
	map<string, VVdouble>::iterator pos;
	if ((pos = _nodes2JumpsProb.find(nodeName)) == _nodes2JumpsProb.end())
	{
		string err="error in simulateJumps::getProb: cannot find node "+nodeName;
		errorMsg::reportError(err);
	}
	int combinedTerminalState = getCombinedState(terminalStart, terminalEnd);
	int combinedJumpState = getCombinedState(fromId, toId);
	return (pos->second[combinedTerminalState][combinedJumpState]);
}