//                                Cmmtest                             

//  Robin Morisset, ENS & INRIA Paris-Rocquencourt
//  Pankaj Pawan, IIT Kanpur & INRIA Paris-Rocquencourt  
//  Pankaj Prateek, IIT Kanpur & INRIA Paris-Rocquencourt  
//  Francesco Zappa Nardelli, INRIA Paris-Rocquencourt

// The Cmmtest tool is copyright 2012, 2013 Institut National de
// Recherche en Informatique et en Automatique (INRIA).

// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions 
// are met: 
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// 3. The names of the authors may not be used to endorse or promote
// products derived from this software without specific prior written
// permission.

// THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS 
// OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 
// ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY 
// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 
// GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
// IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 
// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
// IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

// When true, the analysis routines in this file print verbose tracing to stdout.
bool debug = false;
// Address bits tested by the analysis code: a memory address with any of these
// bits set is handled as a stack location (see irrelevant_reads_analysis).
#define STACK_MASK 0x700000000000
// Indices of the loads that have been marked as used; filled by
// irrelevant_reads_analysis on branches and on stores to non-stack memory.
set<int> used_index;
// NOTE(review): not referenced in the visible part of this file; presumably
// holds the indices of loads found to be irrelevant -- confirm against its users.
vector<int> irr_index;
// Dependency table: the key is either a register number or a memory address,
// the value is the list of load indices that the content of that register or
// location currently depends on.
map<ADDRINT, vector<int> >access_table;

// For each load index, the indices of the stores whose written value was
// computed from that load.
unordered_map<unsigned long, vector<int> > loads2stores;
// When true, irrelevant_reads_analysis returns immediately on branch
// instructions, so registers read by jumps are not marked as used.
bool ignore_reads_in_jumps = false; 
// Meant to ignore the variable reads done for jumps in functions of type "safe_".

// Writes the diagnostic message s to stdout on a line of its own and flushes
// the stream. The message is not modified.
VOID warning(string &s){
  cout << s;
  cout << endl;
}

// Reports whether the instruction located at instr_addr starts with the two
// opcode bytes 0F AE.
//
// MFENCE is encoded as 0F AE F0. LFENCE (0F AE E8) and SFENCE (0F AE F8) share
// the same first two bytes, so they pass this test as well, and so does any
// other 0F AE-prefixed opcode: the third byte is deliberately not inspected,
// exactly as before.
// TODO : decide whether LFENCE/SFENCE should be told apart from MFENCE.
//
// instr_addr must point to at least two readable bytes of code.
static bool checkIfMfence(ADDRINT instr_addr){
  // Compare the bytes one at a time instead of loading an unsigned short
  // through a cast pointer: instruction addresses are not 2-byte aligned in
  // general and the punned load also violates the aliasing rules. On x86
  // (little endian) this is the same test as "*(unsigned short*)addr == 0xae0f".
  const unsigned char *opcode = reinterpret_cast<const unsigned char *>(instr_addr);
  return opcode[0] == 0x0f && opcode[1] == 0xae;
}

// Number of dynamically executed global memory accesses. The analysis code
// visible in this file only reads it, using it as the index of the current
// access. NOTE(review): presumably incremented by instrumentation code
// elsewhere -- confirm.
INT64 indexCount = 0;

// Debugging aid: prints every entry of access_table to stdout, one per line,
// as "<key> : ,  i1,  i2 ..." between a header and a footer line.
//
// A key is printed as a hexadecimal address when it looks like a memory
// address, and as a short register name otherwise. "Looks like an address"
// uses the same two-part test that irrelevant_reads_analysis applies before
// storing an address in the table (a STACK_MASK bit set, or key >> 20 == 0x4);
// testing STACK_MASK alone would hand addresses of the second kind to
// REG_StringShort as if they were register numbers.
VOID dump_access_table() {
  const map<ADDRINT, vector<int> > &table = access_table;

  cout << "--TABLE DUMP--" << endl;
  for (map<ADDRINT, vector<int> >::const_iterator entry = table.begin();
       entry != table.end(); ++entry) {
    const ADDRINT key = entry->first;
    const bool is_address =
        (key & STACK_MASK) != 0 || ((UINT64)key >> 20) == 0x4;

    if (is_address)
      cout << hex << key << " : " << dec;
    else
      cout << REG_StringShort((REG)key) << " : " ;

    // Bind by reference: the list is only printed, there is no need to copy it.
    const vector<int> &loads = entry->second;
    for (vector<int>::size_type pos = 0; pos < loads.size(); ++pos)
      cout << ",  " << loads[pos] ;
    cout << endl;
  }
  cout << "--------------" << endl;
}

// VOID irrelevant_reads_analysis(bool memory_read, bool memory_write, 
//     bool memory_read_write, ADDRINT memory_address, set<int> *read_set, 
//     set<int> *write_set, set<int>* read_write_set, bool branch_taken,
//     bool stack_read, bool stack_write, ADDRINT instr_addr, BOOL branch, 
//     BOOL rep_ins) {
VOID irrelevant_reads_analysis(bool memory_read, bool memory_write, 
    bool memory_read_write, ADDRINT memory_address, set<int> *read_set, 
    set<int> *write_set, set<int>* read_write_set, bool branch_taken,
    bool is_read, bool is_write, ADDRINT instr_addr, BOOL branch, 
    BOOL rep_ins) {

  map<ADDRINT, vector<int> >::iterator im;
  set<int>::iterator is;
  vector<int>::iterator iv;

  if (memory_read || memory_write || memory_read_write) {
    if(debug)
      cout << "\taddress computed " << hex << memory_address << endl;
  } else {
    memory_address = 0;
  }
  if (debug) {
    cout << "\t Read Set : ";
    for(is = (*read_set).begin();is != (*read_set).end() ; is++){
      cout << REG_StringShort((REG)*is) << " " ; 
    }
    cout << endl <<"\t Write Set : " ;
    for(is = (*write_set).begin();is != (*write_set).end() ; is++){
      cout << REG_StringShort((REG)*is) << " " ; 
    }
    cout << endl << "\t ReadWrite Set : " ;
    for(is = (*read_write_set).begin();is != (*read_write_set).end() ; is++){
      cout << REG_StringShort((REG)*is) << " " ; 
    }
    cout << endl;
  }

  if (debug) {
    cout << hex << instr_addr << dec << endl;
    cout << "\t" << " memR : " << memory_read << "; memW : " << memory_write << "; memRW : " << memory_read_write << endl;
    cout << dec ;
  }
  
  bool stack_read=false, stack_write=false;
  
  if(is_read) {
    UINT64 result = (UINT64)memory_address & STACK_MASK;
    if (result) stack_read=true;
    result = ( (UINT64)memory_address >> 20 ) ^ 0x4;
    if (!result) stack_read=true;
  }

  if(is_write) {
    UINT64 result = (UINT64)memory_address & STACK_MASK;
    if (result) stack_write=true;
    result = ( (UINT64)memory_address >> 20 ) ^ 0x4;
    if (!result) stack_write=true;
  }

  if(stack_read && debug)
    cout<<"\tStack Read"<<endl;
  if(stack_write && debug)
    cout<<"\tStack Write"<<endl;

  //  if ((memory_address & STACK_MASK) && (!stack_read && !stack_write)) {
  if ( ( (memory_address & STACK_MASK) && !(((UINT64)memory_address >> 20) ^ 0x4) ) && (!stack_read && !stack_write) ) {
    stack_read = memory_read;
    stack_write = memory_write; 
    if (memory_read_write) {
      stack_read = true ;
      stack_write = true;
    }
  }
  
  if (branch_taken || branch) {
    //FIXME : enters even if the branch is not taken
    //cout << "branch taken" << endl;
    //XXX : remove the reg in read_set from the access_table and marking the reads as useful
    //TODO : think of a more clever heuristic
    //FIXME :assert((*write_set).empty());

    if (ignore_reads_in_jumps) {
	/*DO NOTHING*/
	return;
    }
    for (is = (*read_set).begin();is != (*read_set).end() ; is++) {
      int reg = *is;
      //add all loads in dependency list to used_index
      if (debug)
        cout << "\t" << REG_StringShort((REG)reg) << endl;
      if (access_table.count(reg) > 0) {
        vector<int>temp = access_table[reg];
        for (iv = temp.begin(); iv != temp.end(); iv++) {
	  used_index.insert(*iv);
        }
        //assert(access_table.erase(reg) == 1);
      } else {
        if (debug)
          warning(REG_StringShort((REG)reg).append(" is being read but was not found in the access_table")); 
      }
    }
    return;
  }
  //TODO : think about more assertions
  //can't be two stores ??
  //assert(!(read_write && memory_write));

  //The source operand can be an immediate value, general-purpose register,
  //segment register, or memory location; the destination register can be a
  //general- purpose register, segment register, or memory location


  //Pankaj : Some assertions according to my understanding of the x86 instruction set
  assert(!(memory_read && memory_write) || rep_ins);
  //Pankaj : read_write_set is 2 for leave instruction, for everything else i know its <=1
  assert((*read_write_set).size() <= 2);
  //assert((*write_set).size() <= 2); 
  //assert(read_set.size() <= 2);
  assert((*read_set).size() + (*write_set).size() <= 4);
  //XXX: 4 is too loose, its only for cmpxchg, checking at different points would 
  //be better.

  //-----------------------
  //computer the loads on which the current instruction output depends
  //TODO: add the access_table[info->address] to the dependent loads
      
  set<int> dependant_loads;

  if (stack_read && access_table.count(memory_address)) {
    vector<int>temp = access_table[memory_address];
    for (iv = temp.begin(); iv != temp.end(); iv++) {
      dependant_loads.insert(*iv);
    }
  }
  for (is = (*read_set).begin();is != (*read_set).end() ; is++) {
    int reg = *is;
    //add all loads in dependency list to used_index
    if(access_table.count(reg) > 0){
      vector<int>temp = access_table[reg];
      for (iv = temp.begin(); iv != temp.end(); iv++) {
        dependant_loads.insert(*iv);
      }
    }
    else {
      if (debug)
        warning(REG_StringShort((REG)reg).append(" is being read but was not found in the access_tabl")); 
    }
  }
  for (is = (*read_write_set).begin();is != (*read_write_set).end() ; is++) {
    int reg = *is;
    //add all loads in dependency list to used_index
    if(access_table.count(reg) > 0){
      vector<int>temp = access_table[reg];
      for(iv = temp.begin(); iv != temp.end(); iv++){
        dependant_loads.insert(*iv);
      }
    }
    else{
      if (debug)
        warning(REG_StringShort((REG)reg).append(" is being read but was not found in the access_table")); 
      //maybe because it was considered used and removed before
    }
  }
  vector<int> dependent_loads_vector;
  if (debug)
    cout << "Dependent Loads : " ;
  for (is = dependant_loads.begin();is != dependant_loads.end() ; is++) {
    dependent_loads_vector.push_back(*is);
    if (debug)
      cout << *is << ",";
  }
  if(debug)cout << endl;

  ////--------------------------

  if (memory_read_write) {
    if (debug)
      cout << "memory_read_write" << endl;

    assert(!(memory_write || memory_read));
    if (!(*read_write_set).empty()) { 
      assert((*read_write_set).size() == 1);
      //assert(INS_IsXchg(ins));
      assert((*write_set).empty());
      //the memory operand may contain the address in a register where the exchange
      //will take place. Hence (*read_set) may not be empty.
      //for e.g  xchg qword ptr [rdx], rax ?? . 
      //Use base address register to capture rdx.
      if (!stack_write) {
        //it was a global memory readWrite
        //add indexCount to the dependent loads 
        dependant_loads.insert(indexCount);
        for (is = dependant_loads.begin();is != dependant_loads.end() ; is++) {
          used_index.insert(*is);
          loads2stores[*is].push_back(indexCount+1); //All loads in dependent sets are used to compute the value written in this store
        }
        //used_index.insert(indexCount);
        {
          /*Uncomment to prevent the dataflow...but now we want to find which 
            stores did this write affect, probably an option would be better?
            */
          //for(is = (*read_set).begin();is != (*read_set).end() ; is++){
          //  int reg = *is;
          //  //add all loads in dependency list to used_index
          //  if(access_table.count(reg) > 0){
          //    assert(access_table.erase(reg) == 1);    
          //  }
          //}
          //for(is = (*read_write_set).begin();is != (*read_write_set).end() ; is++){
          //  int reg = *is;
          ////add all loads in dependency list to used_index
          //  if(access_table.count(reg) > 0){
          //    assert(access_table.erase(reg) == 1);    
          //  }
          //}
        }
        dependent_loads_vector.push_back(indexCount);
        for(is = (*read_write_set).begin();is != (*read_write_set).end() ; is++){
          int reg = *is;
          if(access_table.count(reg) > 0){
            assert(access_table.erase(reg) == 1);    
          }
          if(dependent_loads_vector.size())access_table[reg] = dependent_loads_vector;
          else access_table.erase(reg);
        }
      }
      else {
        //don't add indexCount as it was a stack readWrite
        if(memory_address && dependent_loads_vector.size())
          access_table[memory_address] = dependent_loads_vector;
        for(is = (*read_write_set).begin();is != (*read_write_set).end() ; is++){
          int reg = *is;
          if(access_table.count(reg) > 0){
            assert(access_table.erase(reg) == 1);    
          }
          if(dependent_loads_vector.size())access_table[reg] = dependent_loads_vector;
          else access_table.erase(reg);
        }
      }
    }
    else{
      if(!stack_write){
        //it was a global memory readWrite
        //add indexCount to the dependent loads 
        dependant_loads.insert(indexCount);
        for(is = dependant_loads.begin();is != dependant_loads.end() ; is++){
          used_index.insert(*is);
          loads2stores[*is].push_back(indexCount+1); //All loads in dependent sets are used to compute the value written in this store
        }
        //used_index.insert(indexCount);
        {
          /*Uncomment to prevent the dataflow...but now we want to find which 
            stores did this write affect, probably an option would be better?
            */
          //for(is = (*read_set).begin();is != (*read_set).end() ; is++){
          //  int reg = *is;
          //  //add all loads in dependency list to used_index
          //  if(access_table.count(reg) > 0){
          //    assert(access_table.erase(reg) == 1);    
          //  }
          //}
          //for(is = (*write_set).begin();is != (*write_set).end() ; is++){
          //  reg = *is;
          //  //removing write_set entries from access_table as the loads they 
          //  //carry are already marked as used
          //  if(access_table.count(reg) > 0){
          //    assert(access_table.erase(reg) == 1);    
          //  }
          //}
        }
        dependent_loads_vector.push_back(indexCount);
        for (is = (*write_set).begin();is != (*write_set).end() ; is++) {
          int reg = *is;
          if(access_table.count(reg) > 0){
            assert(access_table.erase(reg) == 1);    
          }
          if (dependent_loads_vector.size())
	    access_table[reg] = dependent_loads_vector;
          else access_table.erase(reg);
        }
      } else {
        //don't add indexCount as it was a stack readWrite
        assert(memory_address & STACK_MASK);
        //access_table[memory_address] = dependent_loads_vector;
        if(memory_address && dependent_loads_vector.size())
          access_table[memory_address] = dependent_loads_vector;
        for(is = (*write_set).begin();is != (*write_set).end() ; is++){
          int reg = *is;
          if(access_table.count(reg) > 0){
            assert(access_table.erase(reg) == 1);    
          }
          if(dependent_loads_vector.size())access_table[reg] = dependent_loads_vector;
          else access_table.erase(reg);
        }
      }
    }
  } else if (memory_read /*&& !(*read_write_set).empty()*/){

    if(debug)
      cout << "memory_read " << indexCount  << endl;

      //assert(!memory_read_write);
      //assert((*read_write_set).size() <= 2); //XXX : 2 only for leave instruction check again.
      //----------------
      //assert((*read_set).size() <= 1);
      //XXX : while removing any entry from the access table 
      //first search if any entry has the same indexCount, if yes then remove
      //that entry also, else it would result in a false positive
      //for e.g 
      //    mov eax <- [x]  //entry for eax
      //    sub eax, ebx    //update to cflag so entry of cflags with same index
      //    cmove ...        //depends on cflags and hence the load was relevant 
      //For this instruction eax and rflags both are updated so we will have an 
      //entry for rflags. Now while parsing the cmove both entries must be cleared!.

      if(!stack_read){
        if(debug)
          cout << "global memory read" << endl;
        dependent_loads_vector.push_back(indexCount);
        for(is = (*read_write_set).begin();is != (*read_write_set).end() ; is++){
          int reg = *is;
          if(access_table.count(reg) > 0){
            assert(access_table.erase(reg) == 1);    
          }
          if(dependent_loads_vector.size())access_table[reg] = dependent_loads_vector;
          else access_table.erase(reg);
        }

        for(is = (*write_set).begin();is != (*write_set).end() ; is++){
          int reg = *is;
          if(access_table.count(reg) > 0){
            assert(access_table.erase(reg) == 1);    
          }
          if(dependent_loads_vector.size())access_table[reg] = dependent_loads_vector;
          else access_table.erase(reg);
        }
      }
      else {
        //FIXME : assert(access_table.count(memory_address));
        //stack read can be done only if its written before
        for(is = (*read_write_set).begin();is != (*read_write_set).end() ; is++){
          int reg = *is;
          if(access_table.count(reg) > 0){
            assert(access_table.erase(reg) == 1);    
          }
          if(dependent_loads_vector.size())access_table[reg] = dependent_loads_vector;
          else access_table.erase(reg);

        }
        for(is = (*write_set).begin();is != (*write_set).end() ; is++){
          int reg = *is;
          if(access_table.count(reg) > 0){
            assert(access_table.erase(reg) == 1);    
          }
          if(dependent_loads_vector.size())access_table[reg] = dependent_loads_vector;
          else access_table.erase(reg);
        }
      }
  }
  else if(memory_write /*&& !(*read_set).empty()*/){
    if(debug)
      cout << "memory_write" << endl;  
    assert(!memory_read_write);
    //FIXME :assert((*read_write_set).empty());
    //Not even implicit operands are register writes? Haven't found any thing yet!
    //On more experiments, if there is a write to a register and a write to a memory
    //then the memory operand is read also!(covered in first case)
    //assert((*write_set).empty());  
    //what if the compiler introduces a read and a write to the same location,
    //it would kill the flow and the introduced read wont be reported
    //Assuming it doesn't for now.
    //can detect this by logging memory addresses too? 
    if(!stack_write){
      //it was a global memory Write
      for(is = dependant_loads.begin();is != dependant_loads.end() ; is++){
        used_index.insert(*is);
        loads2stores[*is].push_back(indexCount+1); //All loads in dependent sets are used to compute the value written in this store
      }
      //for(is = (*read_set).begin();is != (*read_set).end() ; is++){
      //  int reg = *is;
      //  //add all loads in dependency list to used_index
      //  if(access_table.count(reg) > 0){
      //    assert(access_table.erase(reg) == 1);    
      //  }
      //}
      //for(is = (*read_write_set).begin();is != (*read_write_set).end() ; is++){
      //  int reg = *is;
      //  //add all loads in dependency list to used_index
      //  if(access_table.count(reg) > 0){
      //    assert(access_table.erase(reg) == 1);    
      //  }
      //}
      for(is = (*write_set).begin();is != (*write_set).end() ; is++){
        int reg = *is;
        if(access_table.count(reg) > 0){
          assert(access_table.erase(reg) == 1);    
        }
        if(dependent_loads_vector.size())access_table[reg] = dependent_loads_vector;
        else access_table.erase(reg);
      }
      for(is = (*read_write_set).begin();is != (*read_write_set).end() ; is++){
        int reg = *is;
        if(access_table.count(reg) > 0){
          assert(access_table.erase(reg) == 1);    
        }
        if(dependent_loads_vector.size())access_table[reg] = dependent_loads_vector;
        else access_table.erase(reg);
      }
      //read_set should be as it is
    }
    else {
      //don't add indexCount as it was a stack readWrite
      assert(memory_address);

      if(dependent_loads_vector.empty())access_table.erase(memory_address);
      else access_table[memory_address] = dependent_loads_vector;
         
      //access_table[memory_address] = dependent_loads_vector;
      for(is = (*write_set).begin();is != (*write_set).end() ; is++){
        int reg = *is;
        if(access_table.count(reg) > 0){
          assert(access_table.erase(reg) == 1);    
        }
        if(dependent_loads_vector.size())access_table[reg] = dependent_loads_vector;
        else access_table.erase(reg);
      }
      for(is = (*read_write_set).begin();is != (*read_write_set).end() ; is++){
        int reg = *is;
        if(access_table.count(reg) > 0){
          assert(access_table.erase(reg) == 1);    
        }
        if(dependent_loads_vector.size())access_table[reg] = dependent_loads_vector;
        else access_table.erase(reg);
      }
    }
  } else if (!(*write_set).empty() || !(*read_set).empty() || !(*read_write_set).empty()) {
    if (debug)
      cout << "no memory accesses" << endl;
    assert ((memory_read || memory_write || memory_read_write) == 0); 
    for (is = (*write_set).begin();is != (*write_set).end() ; is++) {
        int reg = *is;
        //add all loads in dependency list to used_index
        if(access_table.count(reg) > 0){
          assert(access_table.erase(reg) == 1);    
        }
        if(dependent_loads_vector.size())access_table[reg] = dependent_loads_vector;
        else access_table.erase(reg);
    }
    for(is = (*read_write_set).begin();is != (*read_write_set).end() ; is++){
      int reg = *is;
      //add all loads in dependency list to used_index
      if(access_table.count(reg) > 0){
        assert(access_table.erase(reg) == 1);    
      }
      if(dependent_loads_vector.size())access_table[reg] = dependent_loads_vector;
      else access_table.erase(reg);
    }
  } else {
//    analyse_instruction(ins);
    if (debug) {
      cout << "MRW\tMR\tMW\tRRW\tRR\tRW"<<endl;
      cout << memory_read_write << "  \t" << memory_read << " \t" 
	   << memory_write << " \t" << (*read_write_set).size() 
	   << "  \t"<< (*read_set).size() << " \t" << (*write_set).size() << endl;
    }
    if (!checkIfMfence(instr_addr))
      assert(0);
  }
  //XXX : DO NOT free up the read/write sets
  if (debug) 
    dump_access_table();
}

/* ===================================================================== */
// Instrumentation callbacks
/* ===================================================================== */

/*!
 * Insert call to the CountBbl() analysis routine before every basic block 
 * of the trace.
 * This function is called every time a new trace is encountered.
 * @param[in]   trace    trace to be instrumented
 * @param[in]   v        value specified by the tool in the TRACE_AddInstrumentFunction
 *                       function call
 */

VOID analyse_instruction(INS ins){
  if(INS_IsNop(ins)) return; 

  if(UINT32 op = INS_OperandCount(ins)){
    int temp = 0, reg_read_count = 0, reg_write_count = 0, reg_read_write = 0;
    int mem_read = 0, mem_write = 0, read_write = 0;
    for(unsigned int i = 0; i < op ; i++){
      if(INS_OperandIsImplicit(ins, i)){
        temp++;
        cout << REG_StringShort (INS_OperandReg(ins, i))<< " " << INS_OperandReg(ins,i) << " " << REG_FullRegName(INS_OperandReg(ins, i)) << " " ;
        if(INS_OperandIsReg(ins, i)) {
          
          if(INS_OperandReadAndWritten(ins, i)){reg_read_write++;}
          else if(INS_OperandRead(ins, i))reg_read_count++;
          else reg_write_count++;
          cout << INS_OperandRead(ins, i) << " " << INS_OperandWritten(ins, i)<< " "; 
        }
      }
      else 
        if(INS_OperandIsReg(ins, i)){
          if(INS_OperandReadAndWritten(ins, i)){reg_read_write++;}
          else if(INS_OperandRead(ins, i))reg_read_count++;
          else reg_write_count++;
        }
        else if(INS_OperandIsMemory(ins, i)){
          if(INS_OperandReadAndWritten(ins, i)){read_write++;}
          else if(INS_OperandRead(ins, i))mem_read++;
          else mem_write++;
        }
    }
    cout << endl;
    cout << op - temp ;
    
    op = INS_MemoryOperandCount(ins);
    int stackR=0, stackW = 0;
    if(INS_IsStackRead(ins) ){
      op--;
      stackR = 1;
    }
    if(INS_IsStackWrite(ins) ){
      stackW = 1;
      op--;
    }
     cout << " MT :  " << op << " MRW : " << read_write<<  " MR : " << mem_read << " MW : " << mem_write << " stackR " << stackR << " SW " << stackW  ; 
    cout << " RRW " << reg_read_write << " RW "<<reg_write_count << "/"<< INS_MaxNumWRegs(ins) << " RR " << reg_read_count<<"/"<< INS_MaxNumRRegs(ins)   << endl;
    cout << "\t" << hex << INS_Address(ins) << " " << INS_Disassemble(ins) << endl;
  }
  
  return;
}

