from z3 import *
from z3mat import *
import numpy

# Symbolic operands for the proofs below: A is built with size argument 3
# (presumably a 3x3 matrix of Z3 variables -- see z3mat.z3matrix) and x with
# shape (3, 1), i.e. a column vector.  Both calls run in the same order as
# before; the names are bound only after both matrices are built.
A, x = z3matrix('A', 3), z3matrix('x', (3,1))

def mat_eq(m1, m2):
  """Build a Z3 formula stating that matrices m1 and m2 agree entry by entry.

    Entries are compared pairwise with matmap2, and the resulting matrix of
    Z3 booleans is folded into a single formula with z3matall.
    Mismatched shapes raise IndexError (presumably from matmap2 -- confirm
    against z3mat).
  """
  def entries_equal(lhs, rhs):
    return lhs == rhs
  comparisons = matmap2(entries_equal, m1, m2)
  return z3matall(comparisons)

# print(...) with a single argument behaves identically on Python 2 and is
# required on Python 3 (the print statement is a SyntaxError there).
print('Ready to prove (x.T * A.T).T == A*x')
# z3's prove() prints 'proved' or a counterexample for the formula.
prove(mat_eq((x.T * A.T).T, A*x))

# show the equality fails if we forget to transpose A
prove(mat_eq((x.T * A).T, A*x))

# to understand the counter-example, let's put it back
# into numpy form

def get_cex(p):
  """Look for a counterexample to the Z3 formula p.

    Returns one of:
      'proved'  -- Not(p) is unsatisfiable, so p is a theorem;
      a model   -- the solver's model satisfying Not(p), i.e. a counterexample;
      'unknown' -- Z3 could not decide Not(p).
  """
  solver = Solver()
  solver.add(Not(p))
  verdict = solver.check()
  if verdict == unsat:
    return 'proved'
  if verdict == sat:
    return solver.model()
  return 'unknown'

m = get_cex(mat_eq((x.T * A).T, A*x))
# get_cex returns the string 'proved' or 'unknown' rather than a model when
# there is no counterexample to extract, so only decode an actual model.
if isinstance(m, str):
  print('no counterexample available: ' + m)
else:
  A_cex = z3matFromModel(A, m)
  x_cex = z3matFromModel(x, m)
  # print(...) with one argument works on both Python 2 and Python 3.
  print('A_cex = ' + str(A_cex))
  print('x_cex = ' + str(x_cex))
  print('(x_cex.T * A_cex).T = ' + str((x_cex.T * A_cex).T))
  print('A_cex * x_cex = ' + str(A_cex * x_cex))

# A freebie -- your reward for reading this far
def symmetric(A):
  """symmetric(A): return a Z3 expression that is satisfied iff A is a
    symmetric matrix, i.e. A[i,j] == A[j,i] for every pair of indices.

    A must expose a two-element .shape and support A[i,j] indexing.  A 0x0
    or 1x1 matrix is trivially symmetric, so BoolVal(True) is returned.

    Raises IndexError if A is not a matrix (no 2-D .shape) or is not a
    square matrix.
  """
  # Validate the shape explicitly instead of wrapping the whole body in a
  # catch-all, which used to swallow the "not square" error and report it
  # with the wrong message.
  try:
    n_rows, n_cols = A.shape
  except (AttributeError, TypeError, ValueError):
    raise IndexError('symmetric(A): A is not a matrix')
  if n_rows != n_cols:
    raise IndexError('symmetric(A): A is not a square matrix')
  # Only the strictly-lower triangle needs checking: the diagonal is equal
  # to itself and each upper-triangle pair mirrors a lower-triangle one.
  conds = [A[i, j] == A[j, i] for i in range(n_rows) for j in range(i)]
  if not conds:
    # z3's And() needs at least one Z3 argument; an empty conjunction is true.
    return BoolVal(True)
  return And(*conds)
