/* SI 413 Fall 2012
 * Recursive descent parser and interpreter
 * for a simple calculator language.
 */

#include <iostream>
#include <cstdlib>
#include "bisoncalc.tab.hpp"
using namespace std;

//-- Prototypes and globals
// Scanner entry point. Only declared here; the definition lives elsewhere
// (presumably the generated lexer that goes with bisoncalc.tab.hpp -- confirm).
int yylex();

// One function per grammar nonterminal. The int-returning ones evaluate the
// input while parsing it; the *tail functions take the value of the left
// operand that has been computed so far.
void stmt(); 
int exp();  int exptail(int lhs); 
int term(); int termtail(int lhs); 
int factor(); 
// Token code of the buffered lookahead token. -1 means nothing is buffered,
// so peek() has to fetch a new token from yylex().
int next = -1;
// Copy of yylval taken by peek() at the moment the lookahead token was read.
YYSTYPE nextval;

//-- Helper functions
// Reports a parse failure to standard error, naming the grammar rule (or
// helper) nt in which it was detected, then ends the program with exit
// status 1. This function never returns to its caller.
void perror(const char* nt) {
  std::cerr << "Parse error in ";
  std::cerr << nt << std::endl;
  std::exit(1);
}

int peek() { 
  if (next == -1) {
    next = yylex(); 
    nextval = yylval;
  }
  return next; 
}

// Consumes the lookahead token, which must have token code tok, and returns
// the semantic value that was saved for it. If the lookahead is any other
// token, a parse error is reported through perror("match"), which exits.
YYSTYPE match(int tok) {
  // Handle the failure first so that every path that reaches the end of the
  // function returns a value (perror() terminates the program).
  if (tok != peek()) perror("match");

  // Written as ::next because `using namespace std;` also exposes std::next,
  // which makes the unqualified name ambiguous on C++11 and later.
  ::next = -1;  // empty the lookahead buffer so peek() reads a fresh token
  return nextval;
}

//-- Grammar rule functions
// Parses one statement: an expression followed by a STOP token.
// The value of the expression is printed on its own line before the STOP
// token is consumed.
void stmt() {
  const int value = exp();
  std::cout << value << std::endl;
  match(STOP);
}

// Parses and evaluates an expression: a term followed by an expression tail.
// Returns the value of the whole expression.
int exp() {
  return exptail(term());
}

// Parses the remainder of an expression whose left operand has already been
// evaluated to lhs, and returns the value of the complete expression.
//   - lookahead RP or STOP: the expression is finished; lhs is the result.
//   - lookahead OPA: the operator is consumed, the following term is
//     evaluated, and it is added to lhs when the operator symbol is '+' and
//     subtracted otherwise. The tail is then parsed again with the new
//     running value, so operators are applied left to right.
//   - anything else: parse error (perror exits the program).
int exptail(int lhs) {
  const int tok = peek();
  if (tok == RP || tok == STOP) return lhs;
  if (tok != OPA) perror("exptail");

  const bool adding = (match(OPA).sym == '+');
  const int rhs = term();
  return exptail(adding ? lhs + rhs : lhs - rhs);
}

// Parses and evaluates a term: a factor followed by a term tail.
// Returns the value of the whole term.
int term() {
  return termtail(factor());
}

// Parses the remainder of a term whose left operand has already been
// evaluated to lhs, and returns the value of the complete term.
//   - lookahead RP, STOP or OPA: the term is finished; lhs is the result.
//   - lookahead OPM: the operator is consumed, the following factor is
//     evaluated, and lhs is multiplied by it when the operator symbol is '*'
//     and divided by it otherwise. The tail is then parsed again with the
//     new running value, so operators are applied left to right.
//   - anything else: parse error (perror exits the program).
// Dividing by zero is reported on standard error and ends the program with
// exit status 1 instead of executing an undefined integer division.
int termtail(int lhs) {
  const int tok = peek();
  if (tok == RP || tok == STOP || tok == OPA) return lhs;
  if (tok != OPM) perror("termtail");  // does not return

  const bool multiplying = (match(OPM).sym == '*');
  const int rhs = factor();
  if (multiplying) return termtail(lhs * rhs);

  if (rhs == 0) {
    // Integer division by zero is undefined behaviour; fail cleanly instead.
    std::cerr << "Division by zero" << std::endl;
    std::exit(1);
  }
  return termtail(lhs / rhs);
}

// Parses and evaluates a factor, which is either
//   - a NUM token, whose saved value (.val) is the result, or
//   - an expression enclosed between an LP token and an RP token, whose
//     value is the result.
// Any other lookahead token is a parse error (perror exits the program).
int factor() {
  const int tok = peek();
  if (tok == NUM) return match(NUM).val;
  if (tok != LP) perror("factor");

  match(LP);
  const int inner = exp();
  match(RP);
  return inner;
}

// Interactive read-evaluate-print loop. A "> " prompt is written (and
// flushed) before each statement is read. The loop ends when peek() reports
// token code 0 -- by yacc/bison convention that is what yylex() returns at
// end of input (confirm against the scanner). A final newline is printed
// before exiting with status 0.
int main() {
  std::cout << "> " << std::flush;
  while (peek() != 0) {
    stmt();
    std::cout << "> " << std::flush;
  }
  std::cout << std::endl;
  return 0;
}

