# Program Name: piS
# Calculates pi (the mean pairwise difference) and S, the number of
# segregating sites, and from them an estimate of theta and Tajima's D.
#
# The format for a *.seq file is like this:
# number_of_sequences sequence_length
# sequence_name1 SEQUENCE1
# sequence_name2 SEQUENCE2
# Comments are delimited by sharp signs, as in the file you are now
# reading.
# Line 1: Contains number of samples, length of each sequence
# Remaining lines have two fields: 
#    Field 1: name of sequence
#    Field 2: the sequence
# Example:
#   2 13
#   hiv1 GACTGACTGATCG
#   hiv2 GCATGCTAGCTAG
#
# Caveats: Sequences need - to be aligned
#                         - to be the same length
#                         - to be on one line (lines can run over the edge, like this one).
#
# Global state shared by the rules below:
#   have_header - 1 once the "<number of sequences> <sequence length>"
#                 header has been read
#   G, nsites   - the two values taken from that header
#   ngenes      - number of sequence lines stored so far
#   sequence[]  - sequence[i] is the i-th sequence string (1-based)
#   failed      - 1 when a parse error was reported; END checks it because
#                 awk still runs the END action after `exit` in a main rule
BEGIN {
    ngenes = 0;
    have_header = 0;
    failed = 0;
}

# Normalise every record before it is classified.
{
    rawline = $0;        # kept verbatim for error messages
    sub(/\r$/, "");      # tolerate CRLF line endings
    sub(/#.*/, "");      # drop comments; assigning to $0 re-splits the fields
}

# Skip lines that are empty or held only a comment.
NF == 0 {
    next
}

# The first line with content is the header: number of genes, number of sites.
!have_header {
    if ($1 !~ /^[0-9]+$/ || $2 !~ /^[0-9]+$/) {
        printf("ERROR: line %d is not a valid header (need: number of sequences, sequence length)\n",
               NR);
        printf("bad line: \"%s\"\n", rawline);
        failed = 1;
        exit 1;
    }
    G = $1 + 0;          # number of genes
    nsites = $2 + 0;     # number of sites
    have_header = 1;
    next
}

# Every remaining line must be "<name> <sequence>".
{
    if (NF != 2) {
        printf("ERROR: line %d has %d fields (should have 2)\n",
               NR, NF);
        printf("bad line: \"%s\"\n", rawline);
        failed = 1;
        exit 1;
    }
    # put sequences into an array
    sequence[++ngenes] = $2;
}

# Validate what was read, then compute and print pi, S, theta and Tajima's D.
# Exits with status 1 (after printing an ERROR line) on any invalid input.
END {
    # A parse error was already reported; do not compute on partial data.
    if (failed)
        exit 1;

    # Did we read the right number of sequences?
    if (ngenes != G) {
        printf("ERROR read %d sequences; expected %d\n",
               ngenes, G);
        exit 1;
    }

    # With fewer than two sequences there are no pairs: G*(G-1)/2 and a1
    # would both be zero and the divisions below would abort awk.
    if (G < 2) {
        printf("ERROR need at least 2 sequences; got %d\n", G);
        exit 1;
    }

    # Check for uneven sequence length
    for (i = 1; i <= G; i++) {
        if (length(sequence[i]) != nsites) {
            printf("ERROR - length(%d)=%d; nsites=%d\n",
                   i, length(sequence[i]), nsites);
            exit 1;
        }
    }

    # Using the first sequence as a reference, see if any sites
    # in each other sequence are different. A site is segregating
    # exactly when at least one sequence differs from the reference
    # there, so marking those sites in varsites is sufficient.
    split(sequence[1], refseq, "");
    for (i = 2; i <= G; i++) {
        split(sequence[i], compseq, "");
        for (j = 1; j <= nsites; j++) {
            if (refseq[j] != compseq[j])
                varsites[j] = 1;
        }
    }

    # Count segregating sites.
    S = 0;
    for (j = 1; j <= nsites; j++) {
        if (j in varsites)
            S++;
    }

    # Count differences between each pair of sequences
    npairs = G*(G-1)/2;
    printf("Doing %d pairwise comparisons...be patient\n", npairs);
    pi = 0;
    for (i = 1; i <= G; i++) {
        printf("%2d:\r", i);          # progress indicator
        split(sequence[i], seqi, "");
        for (j = 1; j < i; j++) {
            split(sequence[j], seqj, "");
            for (k = 1; k <= nsites; k++) {
                if (seqi[k] != seqj[k])
                    pi++;
            }
        }
    }
    # divide by the number of pairwise comparisons
    pi /= npairs;

    # divisor for estimating theta from S
    a1 = a2 = 0;
    for (i = 1; i <= (G-1); i++) {
        a1 += 1/i;
        a2 += 1/(i*i);
    }
    thetaS = S/a1;

    printf("%d sequences, %d sites\n", G, nsites);
    printf("                         pi: %f \n", pi);
    printf("          Segregating sites: %d/%d\n", S, nsites);
    printf("theta_hat[estimated from S]: %f\n", thetaS);

    # Calculate Tajima's D:
    b1 = (G+1.0) / (3.0*(G-1.0));
    b2 = 2.0*(G^2 + G + 3)/(9.0*G*(G-1));
    c1 = b1 - 1.0/a1;
    c2 = b2 - (G+2)/(a1*G) + a2/(a1*a1);
    e1 = c1/a1;
    e2 = c2/(a1*a1 + a2);
    variance = e1*S + e2*S*(S-1);

    # For G = 2 or 3 both c1 and c2 are algebraically zero, and for S = 0
    # the variance is exactly zero; in those cases D is undefined and the
    # division would either abort awk or amplify rounding noise.
    if (G < 4 || variance <= 0) {
        printf("                 Tajima's D: undefined (needs at least 4 sequences and 1 segregating site)\n");
    } else {
        D = (pi - thetaS)/sqrt(variance);
        printf("                 Tajima's D: %f\n", D);
    }
    printf("a1=%f a2=%f b1=%f b2=%f\n", a1, a2, b1, b2);
    printf("c1=%f c2=%f e1=%f e2=%f\n", c1, c2, e1, e2);
}