from z3 import *
s = Solver()  # main solver; accumulates the path condition of search()

# Bounds for the bounded (unrolled) verification below.
n_unroll = 5  # number of while-loop iterations to unroll
lenA_max = 8  # termination is checked for every len(A) below this value

# 0: def search(key, A):
#    Symbolic inputs: the key being searched for, the array A (an
#    Int -> Int array in z3), and lenA, which stands in for len(A).
key = Int('key')
A = Array('A', IntSort(), IntSort())
lenA = Int('lenA')
s.add(lenA >= 0)  # a length cannot be negative

# 1: lo = 0
lo = Int('lo')
s.add(lo == 0)

# 2: hi = len(A)
hi = Int('hi')
s.add(hi == lenA)

# Precondition: A[0:lenA] is sorted in non-decreasing order, i.e. for every
# pair of in-range indices i <= j we have A[i] <= A[j].
i, j = Ints('i j')
s.add(ForAll([i, j],
             Implies(And(0 <= i, i <= j, j < lenA), A[i] <= A[j])))

def _report(*parts):
    """Print *parts* separated by single spaces.

    Produces the same text as the Python 2 statement ``print a, b, c``.
    Because the call has a single parenthesised argument, it also behaves
    identically under Python 3.
    """
    print(" ".join(str(p) for p in parts))


# Unroll the `while lo < hi:` loop of search() n_unroll times.  At the top of
# iteration m, `lo`/`hi` name the z3 constants for the loop bounds and `s`
# holds the path condition for reaching that point.
for m in range(n_unroll):
    # Case: the while loop exits after m iterations (the test lo < hi fails).
    # Check this on a fresh solver seeded with a copy of the path condition,
    # so `s` itself is left untouched.
    s0x = Solver()
    s0x.add(s.assertions())
    s0x.add(Not(lo < hi))  # we've exited the loop

    # Leaving the loop means search() did not find the key.  That is an
    # error if key occurs anywhere in A[0:lenA].
    s0x.add(Exists(i, And(0 <= i, i < lenA, A[i] == key)))
    ch = s0x.check()
    if ch == unsat:
        _report("Test passed: while-loop exits after", m, "iterations.")
    elif ch == sat:
        _report("Test failed: while-loop exits after", m,
                "iterations, but key is in A.")
        _report(s0x.model())  # find out why
    else:
        _report("Test failed for while-loop exit after", m,
                "iterations.  z3 gave up.")

    # Case: the loop body executes once more.
    s.add(lo < hi)  # the loop test passed

    # 4: mid = (lo + hi) / 2   (integer division on z3 Int terms)
    mid = Int('mid$' + str(m))
    s.add(mid == (lo + hi) / 2)

    # 5: if A[mid] == key: return mid
    # On this branch search() returns mid, so mid must be a valid index.
    # z3 arrays are total over Int, so A[mid] == key alone says nothing
    # about 0 <= mid < lenA.  That bound is the property to prove here.
    s.push()
    s.add(A[mid] == key)
    s.add(Not(And(0 <= mid, mid < lenA)))
    ch = s.check()
    if ch == unsat:
        _report("Test passed: found key on iteration", m, "of the while loop")
    elif ch == sat:
        _report("Test failed: returned an invalid index on iteration", m,
                "of the while loop")
        _report(s.model())  # find out why
    else:
        _report("Test failed: z3 gave up on iteration", m, "of the while loop")
    s.pop()

    # Otherwise the first if is false and the search interval is narrowed.
    s.add(Not(A[mid] == key))
    # NOTE(review): this models `lo = mid` (not `lo = mid + 1`).  When
    # hi == lo + 1 and A[mid] < key, mid == lo, so neither bound changes and
    # the loop makes no progress.  The termination check after the loop may
    # report this.  Confirm that it matches the search() being verified.
    lo_next, hi_next = Ints(('lo$' + str(m + 1), 'hi$' + str(m + 1)))
    s.add(lo_next == If(A[mid] < key, mid, lo))
    s.add(hi_next == If(A[mid] < key, hi, mid))
    lo, hi = lo_next, hi_next

# At this point `s` describes the paths on which the loop body ran n_unroll
# times without returning.  For arrays shorter than lenA_max no such path
# should exist, so adding lenA < lenA_max must make `s` unsatisfiable.
s.push()
s.add(lenA < lenA_max)
ch = s.check()
if ch == unsat:
    # A single parenthesised argument prints the same text as the Python 2
    # statement `print a, b, c`, and it also runs under Python 3.
    print(" ".join(["Test passed: search returns within", str(n_unroll),
                    "iterations if len(A) < ", str(lenA_max)]))
elif ch == sat:
    print("Test failed: search didn't return when expected")
    print(s.model())
else:
    print("Test failed: can't show termination.  z3 gave up.")
s.pop()  # drop lenA < lenA_max so the push above is balanced
