/* Code written by Brad Penoff (penoff@cs.ubc.ca) 
 *
 *  This code illustrates the communication pattern for
 *   an MPI application that could demonstrate 
 *   head-of-line blocking when loss occurs.
 *
 *  As distributed, it outputs times for all ranks.  The interesting
 *   thing is the range of results across ranks.
 */ 

#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <mpi.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <sys/time.h>
#include <assert.h>

 
/* These can be tweaked however you so desire */
/* Number of iterations of the main loop.  Each iteration runs two blocks
 * (one per direction), so req_array needs 2 * NUM_ITERATE rows. */
#define NUM_ITERATE 30
/* Number of nonblocking sends/receives posted per block.  The stream index
 * is also used as the MPI tag and selects the row of buf_array. */
#define NUMBER_OF_STREAMS 10
/* Bytes transferred per message (the count passed to MPI_Isend/MPI_Irecv). */
#define MSG_SIZE 60000
/* Argument passed to compute(); compute() counts down from time * 1000. */
#define COMPUTE_TIME 1000000


// global variables
// req_array: one row of request handles per block.  main() uses row 2k for
// the first exchange of iteration k and row 2k+1 for the second, so no row
// is reused within a run.  As a global it starts zero-initialized; nothing in
// this file sets the entries to MPI_REQUEST_NULL.
MPI_Request req_array[NUM_ITERATE * 2][NUMBER_OF_STREAMS]; //cannot reuse
// status: scratch output for main()'s MPI_Waitall calls.  Its contents are
// not read anywhere in this file.
MPI_Status status[NUMBER_OF_STREAMS]; //can reuse per block
// buf_array: one message buffer per stream, used for both sends and
// receives.  The extra byte holds a '\0' that main() writes; only MSG_SIZE
// bytes are ever transmitted.
char buf_array[NUMBER_OF_STREAMS][MSG_SIZE + 1];  //can reuse per block


// methods

/*
 * Busy-wait to simulate a computation phase between communication steps.
 * Spins for time * 1000 loop iterations; time <= 0 returns immediately.
 *
 * The counter is volatile so the optimizer cannot delete the otherwise
 * side-effect-free loop (compilers typically remove such loops at -O1 and
 * above, which would make the "compute" phase vanish).  The iteration count
 * is computed in long long so larger COMPUTE_TIME values cannot overflow int,
 * and the "> 0" guard keeps a negative argument from counting down past
 * INT_MIN.
 */
void compute(int time) {
  volatile long long remaining = (long long)time * 1000;

  while (remaining > 0)
    remaining--;
}

void send_block(int to_rank, int iteration) {
  int i, j;
  MPI_Status statusA;

  compute(COMPUTE_TIME);
  for(i = 0; i < NUMBER_OF_STREAMS; i++)
    MPI_Isend(buf_array[i], MSG_SIZE, MPI_BYTE, to_rank,
                i, MPI_COMM_WORLD, &req_array[iteration][i]);
/*    MPI_Send(buf_array[i], MSG_SIZE, MPI_BYTE, to_rank,
      i, MPI_COMM_WORLD);*/

  for(i = 0; i < NUMBER_OF_STREAMS; i++) {
    MPI_Waitany(NUMBER_OF_STREAMS, req_array[iteration], &j, &statusA);
    compute(COMPUTE_TIME);
  }
}

void receive_block(int from_rank, int iteration) {
  int i, j;
  MPI_Status statusA;

  for(i = 0; i < NUMBER_OF_STREAMS; i++)
    /*    MPI_Recv(buf_array[i], MSG_SIZE, MPI_BYTE, from_rank,
	  i, MPI_COMM_WORLD, &statusA);*/
    MPI_Irecv(buf_array[i], MSG_SIZE, MPI_BYTE, from_rank,
	      i, MPI_COMM_WORLD, &req_array[iteration][i]);
  for(i = 0; i < NUMBER_OF_STREAMS; i++) {
    MPI_Waitany(NUMBER_OF_STREAMS, req_array[iteration], &j, &statusA);
    compute(COMPUTE_TIME);
  }
}

/*
 * Entry point.  Ranks pair up as (0,1), (2,3), ...  For NUM_ITERATE
 * iterations each pair runs two blocks: even -> odd, then odd -> even.
 * Afterwards every rank prints its elapsed time, measured both before and
 * after a final MPI_Barrier.
 *
 * With an odd number of ranks, the last (even) rank has no partner.  It
 * skips every exchange, including the MPI_Waitall safety nets.  Its
 * req_array rows are never filled by MPI_Isend/MPI_Irecv, so waiting on
 * them would pass invalid request handles to MPI.
 */
int main (int argc, char *argv[]) {
  int size, k, my_rank, m, n;
  int is_odd, partner, has_partner;
  double start, end, end2;

  /* Fill every stream buffer with 'b' and NUL-terminate it. */
  for (m = 0; m < NUMBER_OF_STREAMS; m++) {
    for (n = 0; n < MSG_SIZE; n++)
      buf_array[m][n] = 'b';
    buf_array[m][MSG_SIZE] = '\0';
  }

  MPI_Init(&argc, &argv);
  MPI_Comm_size(MPI_COMM_WORLD, &size);
  MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);

  is_odd = my_rank % 2;
  partner = is_odd ? my_rank - 1 : my_rank + 1;
  has_partner = partner < size;  /* false only for a trailing even rank */

  start = MPI_Wtime();

  for (k = 0; k < NUM_ITERATE; k++) {
    if (!has_partner)
      continue;  /* odd number of ranks: nothing to exchange */

    /* Phase 1 (request row 2k): even rank sends, odd rank receives. */
    if (is_odd)
      receive_block(partner, k * 2);
    else
      send_block(partner, k * 2);
    /* The block's MPI_Waitany loop has already completed every request;
     * this is only a safety net. */
    MPI_Waitall(NUMBER_OF_STREAMS, req_array[k * 2], status);

    /* Phase 2 (request row 2k+1): odd rank sends, even rank receives. */
    if (is_odd)
      send_block(partner, (k * 2) + 1);
    else
      receive_block(partner, (k * 2) + 1);
    MPI_Waitall(NUMBER_OF_STREAMS, req_array[(k * 2) + 1], status);
  }

  end = MPI_Wtime();
  MPI_Barrier(MPI_COMM_WORLD);
  end2 = MPI_Wtime();

  /* Every rank reports; the spread of these times across ranks is the
   * result of interest. */
  printf("Total time before long barrier for %d iterations on grank %d is %f seconds.\n",
         k, my_rank, end - start);
  printf("Total time after long barrier for %d iterations on grank %d is %f seconds.\n",
         k, my_rank, end2 - start);

  MPI_Finalize();
  return 0;
}

