from z3 import *
from puzzle_utils import one_to_one_fun;

# Z3 enumeration sorts, one per categorical attribute of a ship.  Each
# EnumSort call yields the sort itself plus one constant per listed name,
# which are unpacked into module-level names here.
Nationality, (Brazilian, English, French, Greek, Spanish) = EnumSort(
  "Nationality", ("Brazilian", "English", "French", "Greek", "Spanish")
)

Cargo, (cocoa, coffee, corn, rice, tea) = EnumSort(
  "Cargo", ("cocoa", "coffee", "corn", "rice", "tea")
)

Chimney, (black, blue, green, red, white) = EnumSort(
  "Chimney", ("black", "blue", "green", "red", "white")
)

# The Python identifier cannot contain a space, but the Z3-level name keeps
# it: the constant bound to PortSaid is named "Port Said".
City, (Hamburg, Genoa, Manila, Marseille, PortSaid) = EnumSort(
  "City", ("Hamburg", "Genoa", "Manila", "Marseille", "Port Said")
)

# Integer-valued attributes are plain Python ranges rather than Z3 sorts:
# departure times 5 through 9, and pier numbers 1 through 5.
Departs = range(5, 10)
Pier = range(1, 6)

# Map every attribute set onto Pier via puzzle_utils.one_to_one_fun.  Each
# call returns (forward map, inverse map, constraint): the forward map gives
# the pier holding a value, the inverse gives the value at a pier, and the
# third item is a Z3 condition (presumably enforcing that the mapping is a
# bijection -- see puzzle_utils) that must be asserted alongside the clues.
nationality, nat_inv, f1 = one_to_one_fun("nationality", Nationality, Pier)
cargo, car_inv, f2 = one_to_one_fun("cargo", Cargo, Pier)
chimney, chi_inv, f3 = one_to_one_fun("chimney", Chimney, Pier)
city, cit_inv, f4 = one_to_one_fun("city", City, Pier)
departs, dep_inv, f5 = one_to_one_fun("departs", Departs, Pier)

# Seed the clue list with the conjunction of all structural constraints.
clues = [
  And(f1, f2, f3, f4, f5),
]

def next_to(pier1, pier2):
  """Return a Z3 constraint stating that two pier positions are adjacent.

  The constraint holds exactly when pier1 is one less than, or one more
  than, pier2.
  """
  one_below = pier1 == pier2 - 1
  one_above = pier1 == pier2 + 1
  return Or(one_below, one_above)

# The fifteen puzzle clues, in their original order, each written as a Z3
# constraint over pier numbers.  nationality(x), cargo(x), chimney(x),
# city(x) and departs(t) all denote the pier of the matching ship.
clues.extend([
  # 1. The Greek ship leaves at 6 and carries coffee.
  And(departs(6) == nationality(Greek), cargo(coffee) == nationality(Greek)),
  # 2. The black chimney is at pier 3, the middle of piers 1-5.
  3 == chimney(black),
  # 3. The English ship leaves at 9.
  departs(9) == nationality(English),
  # 4. The French ship has the blue chimney and is directly before the
  #    coffee ship (pier number one lower).
  And(chimney(blue) == nationality(French),
      nationality(French) == cargo(coffee) - 1),
  # 5. The Marseille ship is directly after the cocoa ship.
  cargo(cocoa)+1 == city(Marseille),
  # 6. The Brazilian ship is bound for Manila.
  nationality(Brazilian) == city(Manila),
  # 7. The rice ship is next to the green chimney.
  next_to(cargo(rice), chimney(green)),
  # 8. The Genoa ship leaves at 5.
  city(Genoa) == departs(5),
  # 9. The Spanish ship leaves at 7 and is directly after the Marseille ship.
  And(nationality(Spanish) == departs(7),
      nationality(Spanish) == city(Marseille)+1),
  # 10. The red chimney belongs to the Hamburg ship.
  chimney(red) == city(Hamburg),
  # 11. The ship leaving at 7 is next to the white chimney.
  next_to(departs(7), chimney(white)),
  # 12. The corn ship is at an end pier (1 or 5).
  Or(cargo(corn) == 1, cargo(corn) == 5),
  # 13. The black-chimney ship leaves at 8.
  chimney(black) == departs(8),
  # 14. The corn ship is next to the rice ship.
  next_to(cargo(corn), cargo(rice)),
  # 15. The Hamburg ship leaves at 6.
  city(Hamburg) == departs(6),
])

# Solve the conjunction of every clue, then print one line per pier listing
# the nationality, cargo, chimney, city and departure time the model assigns.
s = Solver()
s.add(And(*clues))
result = s.check()
if result != sat:
  # A model is only available after a sat answer; without this check
  # s.model() would fail with an unhelpful z3 exception on unsat/unknown.
  raise SystemExit("no solution found: solver returned %s" % result)
m = s.model()
for p in Pier:
  row = [m.eval(inv(p)) for inv in (nat_inv, car_inv, chi_inv, cit_inv, dep_inv)]
  # A single parenthesized argument produces the same output under the
  # Python 2 print statement and the Python 3 print function.
  print("pier=%d: %s" % (p, " ".join(str(v) for v in row)))
