from z3 import *
# Solver for the loop-entry check: it collects the facts established by
# lines 0-2 of search() before the invariant is tested against them.
s = Solver()

# I'll define a few functions that will make the code more concise
def member_test(key, A, lo, hi, i):
  """Return a z3 formula saying that `key` occurs in A[lo:hi].

  `i` is the z3 constant that gets existentially quantified; the formula
  holds when some index i with lo <= i < hi satisfies A[i] == key.
  """
  at_or_above_lo = lo <= i
  below_hi = i < hi
  hits_key = A[i] == key
  return Exists(i, And(at_or_above_lo, below_hi, hits_key))

def non_decreasing(A, lo, hi, i, j):
  """Return a z3 formula saying that A[lo:hi] is sorted in non-decreasing order.

  `i` and `j` are the z3 constants that get universally quantified: for every
  pair of indices with lo <= i <= j < hi the formula requires A[i] <= A[j].
  """
  indices_in_window = And(lo <= i, i <= j, j < hi)
  elements_ordered = A[i] <= A[j]
  return ForAll((i, j), Implies(indices_in_window, elements_ordered))

def search_invariant(key, A, lenA, lo, hi, i, j):
  """Return the loop invariant of the search as a z3 formula.

  The invariant is the conjunction of:
    * A[0:lenA] is non-decreasing,
    * 0 <= lo <= hi <= lenA,
    * if key occurs anywhere in A[0:lenA], then it occurs in A[lo:hi].

  `i` and `j` are the z3 constants handed to the quantified sub-formulas.
  """
  array_is_sorted = non_decreasing(A, 0, lenA, i, j)
  key_somewhere = member_test(key, A, 0, lenA, i)
  key_in_window = member_test(key, A, lo, hi, j)
  return And(array_is_sorted,
             0 <= lo, lo <= hi, hi <= lenA,
             Implies(key_somewhere, key_in_window))
  

def _report(solver, passed, failed, gave_up):
  """Run solver.check() and print the verdict for one proof obligation.

  Every check below asserts the negation of the property to be proved (or,
  for the final check, the error condition itself), so `unsat` means the
  test passed, `sat` means z3 found a counterexample (its model is printed),
  and any other answer means z3 could not decide.

  The parenthesized single-argument print form is used on purpose: it
  behaves the same under Python 2 and Python 3.

  Returns the result of solver.check().
  """
  result = solver.check()
  if result == unsat:
    print(passed)
  elif result == sat:
    print(failed)
    print(solver.model())  # find out why
  else:
    print(gave_up)
  return result


# lines 0, 1, and 2 are just like in bf0.py and bf1.py
# 0: def search(key, A):
key = Int('key')
A = Array('A', IntSort(), IntSort())
lenA = Int('lenA')
s.add(lenA >= 0)

# 1: lo = 0
lo = Int('lo')
s.add(lo == 0)

# 2: hi = len(A)
hi = Int('hi')
s.add(hi == lenA)

# i and j are only ever used as the quantified variables of the formulas
# built by member_test, non_decreasing and search_invariant.
i, j = Ints('i j')
# precondition: the input array is sorted
s.add(non_decreasing(A, 0, lenA, i, j))

# check that the invariant holds
s.push()
s.add(Not(search_invariant(key, A, lenA, lo, hi, i, j)))
ch = _report(
  s,
  "Test passed: Inv holds when execution first reaches the loop",
  "Test failed: Inv does not hold when execution first reaches the loop",
  "Test failed: z3 gave up trying to show that Inv holds when execution first reaches the loop")
s.pop()  # good karma

# From here on a fresh solver is used: all that is assumed at the top of an
# arbitrary iteration is the invariant itself.
s2 = Solver()
s2.add(search_invariant(key, A, lenA, lo, hi, i, j))

# check the loop body
s2.push()
s2.add(lo < hi)  # the loop test passed
mid = Int('mid')
s2.add(mid == (lo + hi)/2)  # '/' on z3 Int terms is integer division

# 5: if(A[mid] == key): return mid
s2.push()
s2.add(A[mid] == key)

s2.push()  # check the return value
# Negating A[mid] == key alone would be vacuous, because the branch condition
# just asserted it.  The returned index must also be a valid position in A.
s2.add(Not(And(0 <= mid, mid < lenA, A[mid] == key)))
ch = _report(
  s2,
  "Test passed: key found correctly",
  "Test failed: bad return value",
  "Test failed: z3 gave up on checking the return value")
s2.pop()  # end check of return value
s2.pop()  # end branch for if(A[mid] == key)

s2.push()  # now look at the other two cases
s2.add(Not(A[mid] == key))

# 6-7: update lo and hi
lo_next = If(A[mid] < key, mid+1, lo)
hi_next = If(A[mid] < key, hi, mid)

s2.push()  # check that the invariant still holds with lo_next and hi_next
s2.add(Not(search_invariant(key, A, lenA, lo_next, hi_next, i, j)))
ch = _report(
  s2,
  "Test passed: invariant holds at the end of the loop body",
  "Test failed: invariant does not hold at the end of the loop body",
  "Test failed: z3 gave up on checking the invariant at the end of the loop body")
s2.pop()  # end of invariant check

# check termination: show a progress function that decreases at each iteration
# The progress function is hi - lo: it must shrink by at least 1 and stay >= 0.
s2.push()
s2.add(Not(And((hi_next - lo_next) <= (hi - lo) - 1, (hi_next - lo_next) >= 0)))
ch = _report(
  s2,
  "Test passed: search terminates",
  "Test failed: search might not terminate",
  "Test failed: z3 gave up on checking termination")
s2.pop()  # end of termination check

s2.pop()  # end of check for Not(A[mid] == key) branches
s2.pop()  # end of check for while-loop body

s2.push()  # now check the code after the while loop.
s2.add(Not(lo < hi))  # loop exit condition
s2.push()  # check return value
# we are going to return None.  It's an error if key is in A
s2.add(member_test(key, A, 0, lenA, i))
ch = _report(
  s2,
  "Test passed: 'return None' means key is not in A",
  "Test failed: might return None even though key is in A",
  "Test failed: z3 gave up checking the 'return None' case")
s2.pop()  # end of return value check
s2.pop()  # end of check for the code after the while loop
