import java.awt.*;
import java.awt.event.*;
//import java.applet.*;
import java.text.*;
import javax.swing.*;

/**
 * This applet demonstrates a simple grid-world game.  It is complex
 * enough so that simple, rule-based strategies are difficult to
 * write, but it is small enough so that state-based MDP and RL
 * algorithms can work. There are lots of ways it can be extended.
 * It isn't designed to be general or reusable.
 * <P>
 * Copyright (C) 2003-2006  <A HREF="http://www.cs.ubc.ca/spider/poole/">David Poole</A>.
 * <p>
 * This program gives the GUI. The GUI is in <A
 * HREF="SGameGUI.java">SGameGUI.java</A>. The environment code is at
 * <A HREF="SGameEnv.java">SGameEnv.java</A>. The controller is at <A
 * HREF="SGameController.java">SGameController.java</A>.
 * <p>
 This program is free software; you can redistribute it and/or
 modify it under the terms of the GNU General Public License
 as published by the Free Software Foundation; either version 2
 of the License, or (at your option) any later version.
<p>
 This program is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 GNU General Public License for more details.
<p>
 You should have received a copy of the GNU General Public License
 along with this program; if not, write to the Free Software
 Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.


 * @author David Poole  poole@cs.ubc.ca
 * @version 0.41 2007-09-09 */



public class SGameGUI extends JApplet
{
    /** Lower bound for {@link #fontSize}, so the font "-" button cannot reach zero or below. */
    private static final int MIN_FONT_SIZE = 1;

    /** Change in {@link #sqsize}, in pixels, per click of a grid-size button. */
    private static final int SQUARE_SIZE_STEP = 5;

    /** Lower bound for {@link #sqsize}, so the grid "-" button cannot reach zero or below. */
    private static final int MIN_SQUARE_SIZE = SQUARE_SIZE_STEP;

    /** Change in the discount text field per click of a discount button. */
    private static final double DISCOUNT_STEP = 0.1;

    /** The controller selected by the "controller" applet parameter. */
    SGameController controller;

    /** The environment owned by {@link #controller}. */
    SGameEnv environment;

    int xDim;
    int yDim;  //these get set from the environment model

    JTextField discountField;

    JTextField initialValueField;
    JTextField greedyField;
    JTextField stepField;
    JTextField alphaField;

    boolean includeAlphaCheck = true;
    String alphaText = "Fixed alpha =";
    final int strutSize=15;

    JLabel reportStepsLabel;
    JLabel reportRewardsLabel;
    JLabel reportMinLabel;
    JLabel reportZeroLabel;

    /** Side length of one grid square, in pixels. */
    int sqsize = 110;
    /** Half-width of the base of the arrow triangles, in pixels. */
    int twid = 5;
    int fontSize = 14;
    Font myFont = new Font("SansSerif", Font.PLAIN, fontSize);

    boolean showCounts=false;
    boolean showCountsOption=true;

    /** Formatter for the values painted on the grid (uses the default locale). */
    DecimalFormat df = new DecimalFormat("0.##");

    /**
     * Formatter for the discount text field.  The field is read back with
     * Double.parseDouble, which only accepts '.' as the decimal separator,
     * so this formatter must not follow the default locale.
     */
    private final DecimalFormat discountFormat =
        new DecimalFormat("0.##", new DecimalFormatSymbols(java.util.Locale.US));

    /**
     * Preferred size of the grid panel.  It is sized in init(), because
     * xDim and yDim are still zero when field initializers run.
     */
    Dimension gridDimension = new Dimension();

    GridPanel graphPanel;

    /**
     * Builds the whole user interface: chooses the controller named by the
     * "controller" applet parameter, lays out the control column on the
     * right, the display options along the bottom and the scrollable grid
     * in the centre, then resets the controller and environment.
     */
    public void init()
    {
        controller = createController(getParameter("controller"));
        environment = controller.getEnvironment();
        xDim = environment.xDim;
        yDim = environment.yDim;

        // The dimensions are only known now; size the grid before the
        // panel takes gridDimension as its preferred size.
        gridDimension.setSize(sqsize*xDim, sqsize*yDim);
        graphPanel = new GridPanel();

        JPanel pan = new JPanel();
        pan.setLayout(new BoxLayout(pan, BoxLayout.Y_AXIS));

        JLabel titleLabel = new JLabel("David's game - "+controller.title);
        pan.add(wrap(titleLabel));

        pan.add(Box.createVerticalGlue());
        pan.add(buildDirectionPanel());
        addSpacer(pan);

        pan.add(buildDiscountPanel());
        addSpacer(pan);

        pan.add(buildStepPanel());
        pan.add(buildGreedyPanel());
        addSpacer(pan);

        pan.add(buildAlphaPanel());
        addSpacer(pan);

        pan.add(buildResetPanel());
        pan.add(buildInitialValuePanel());
        addSpacer(pan);

        reportStepsLabel = new JLabel("Number of steps: 0");
        pan.add(wrap(reportStepsLabel));

        reportRewardsLabel = new JLabel("Total reward received: 0");
        pan.add(wrap(reportRewardsLabel));

        reportMinLabel = new JLabel("Min reward is 0 at step 0");
        pan.add(wrap(reportMinLabel));

        reportZeroLabel = new JLabel("Zero crossing a step: 0");
        pan.add(wrap(reportZeroLabel));

        getContentPane().add(pan, BorderLayout.EAST);
        getContentPane().add(buildDisplayOptionsPanel(), BorderLayout.SOUTH);
        getContentPane().add(new JScrollPane(graphPanel), BorderLayout.CENTER);
        doreset();
        repaint();
    }

    /**
     * Creates the controller for the given name (compared ignoring case).
     * A missing (null) or unrecognised name is reported on standard output
     * and falls back to the plain SGameController.
     */
    private SGameController createController(String whichController)
    {
        // Constant-first comparisons so that a null parameter is not dereferenced.
        if ("Hand".equalsIgnoreCase(whichController))
            return new SGameController();
        if ("QLearning".equalsIgnoreCase(whichController))
            return new SGameQController();
        if ("Model".equalsIgnoreCase(whichController))
            return new SGameModelController(this);
        if ("FunctionApproximation".equalsIgnoreCase(whichController))
            return new SGameFAController(this);
        if ("Adversary".equalsIgnoreCase(whichController))
            return new SGameAdvController();

        System.out.println("Unknown Controller: "+whichController);
        return new SGameController();
    }

    /** Returns a new default-layout panel containing only the given component. */
    private JPanel wrap(JComponent component)
    {
        JPanel panel = new JPanel();
        panel.add(component);
        return panel;
    }

    /** Appends the glue-plus-strut separator used between control groups. */
    private void addSpacer(JPanel pan)
    {
        pan.add(Box.createVerticalGlue());
        pan.add(Box.createVerticalStrut(strutSize));
    }

    /** Builds the Up/Right/Down/Left buttons, which perform actions 0 to 3 respectively. */
    private JPanel buildDirectionPanel()
    {
        JPanel directionPanel = new JPanel();
        directionPanel.setLayout(new BorderLayout());
        addDirectionButton(directionPanel, "Up", BorderLayout.NORTH, 0);
        addDirectionButton(directionPanel, "Right", BorderLayout.EAST, 1);
        addDirectionButton(directionPanel, "Down", BorderLayout.SOUTH, 2);
        addDirectionButton(directionPanel, "Left", BorderLayout.WEST, 3);
        return directionPanel;
    }

    /** Adds one button that performs a single step with the given action, then repaints. */
    private void addDirectionButton(JPanel directionPanel, String label,
                                    String position, final int action)
    {
        JButton button = new JButton(label);
        directionPanel.add(button, position);
        button.addActionListener(new ActionListener()
            {
                public void actionPerformed(ActionEvent event)
                {
                    dostep(action);
                    repaint();
                }
            });
    }

    /** Builds the "Discount" row: a "-" button, the text field and a "+" button. */
    private JPanel buildDiscountPanel()
    {
        JPanel discountPanel = new JPanel();
        discountPanel.add(new JLabel("Discount"));
        discountPanel.add(discountButton("-", -DISCOUNT_STEP));
        discountField = new JTextField(Double.toString(controller.discount),3);
        discountPanel.add(discountField);
        discountPanel.add(discountButton("+", DISCOUNT_STEP));
        return discountPanel;
    }

    /** Creates a button that adds delta to the number shown in the discount field. */
    private JButton discountButton(String label, final double delta)
    {
        JButton button = new JButton(label);
        button.addActionListener(new ActionListener()
            {
                public void actionPerformed(ActionEvent event)
                {
                    double current = Double.parseDouble(discountField.getText());
                    discountField.setText(discountFormat.format(current + delta));
                    repaint();
                }
            });
        return button;
    }

    /** Builds the "Step" button and the field holding the number of steps to run. */
    private JPanel buildStepPanel()
    {
        JPanel stepPanel = new JPanel();

        JButton step = new JButton("Step");
        step.setFont(new Font("SansSerif", Font.BOLD, 20));
        stepPanel.add(step);

        stepField = new JTextField("400",7);
        stepPanel.add(stepField);

        step.addActionListener(new ActionListener()
            {
                public void actionPerformed(ActionEvent event)
                {
                    // doSteps requests the repaint itself.
                    doSteps(Integer.parseInt(stepField.getText()));
                }
            });
        return stepPanel;
    }

    /** Builds the "Greedy Exploit ... %" row. */
    private JPanel buildGreedyPanel()
    {
        JPanel greedyPanel = new JPanel();
        greedyPanel.add(new JLabel("Greedy Exploit"));
        greedyField = new JTextField("80",2);
        greedyPanel.add(greedyField);
        greedyPanel.add(new JLabel("%"));
        return greedyPanel;
    }

    /** Builds the fixed-alpha check box (bound to controller.alphaFixed) and the alpha field. */
    private JPanel buildAlphaPanel()
    {
        JPanel alphaPanel = new JPanel();
        final JCheckBox alphaCheckBox = new JCheckBox(alphaText,controller.alphaFixed);
        alphaCheckBox.addActionListener(new ActionListener()
            {
                public void actionPerformed(ActionEvent event)
                {
                    controller.alphaFixed = alphaCheckBox.isSelected();
                }
            });
        alphaPanel.add(alphaCheckBox);

        alphaField = new JTextField(controller.alpha+"",3);
        alphaPanel.add(alphaField);
        return alphaPanel;
    }

    /** Builds the "Reset" button, which calls doreset() and repaints. */
    private JPanel buildResetPanel()
    {
        JButton reset = new JButton("Reset");
        reset.setFont(new Font("SansSerif", Font.BOLD, 20));
        reset.addActionListener(new ActionListener()
            {
                public void actionPerformed(ActionEvent event)
                {
                    doreset();
                    repaint();
                }
            });
        return wrap(reset);
    }

    /** Builds the "Initial Value" row whose field is read by doreset(). */
    private JPanel buildInitialValuePanel()
    {
        JPanel initialValuePanel = new JPanel();
        initialValuePanel.add(new JLabel("Initial Value"));
        initialValueField = new JTextField("0.0",3);
        initialValuePanel.add(initialValueField);
        return initialValuePanel;
    }

    /**
     * Builds the bottom strip: font-size buttons, grid-size buttons, the
     * console-trace check box and, when showCountsOption is set, the
     * "Show counts" check box.
     */
    private JPanel buildDisplayOptionsPanel()
    {
        JPanel panel = new JPanel();

        panel.add(new JLabel("Font Size"));
        panel.add(fontSizeButton("-", -1));
        panel.add(fontSizeButton("+", 1));

        panel.add(new JLabel("Grid Size: "));
        panel.add(gridSizeButton("-", -SQUARE_SIZE_STEP));
        panel.add(gridSizeButton("+", SQUARE_SIZE_STEP));

        final JCheckBox tracingCheckBox = new JCheckBox("Trace on console");
        tracingCheckBox.addActionListener(new ActionListener()
            {
                public void actionPerformed(ActionEvent event)
                {
                    environment.tracing = tracingCheckBox.isSelected();
                }
            });
        panel.add(tracingCheckBox);

        final JCheckBox showCountsCheckBox = new JCheckBox("Show counts");
        showCountsCheckBox.addActionListener(new ActionListener()
            {
                public void actionPerformed(ActionEvent event)
                {
                    showCounts = showCountsCheckBox.isSelected();
                    repaint();
                }
            });
        if (showCountsOption)
            panel.add(showCountsCheckBox);

        return panel;
    }

    /** Creates a button that changes the font size by delta, never below MIN_FONT_SIZE. */
    private JButton fontSizeButton(String label, final int delta)
    {
        JButton button = new JButton(label);
        button.addActionListener(new ActionListener()
            {
                public void actionPerformed(ActionEvent event)
                {
                    fontSize = Math.max(MIN_FONT_SIZE, fontSize + delta);
                    myFont = new Font("SansSerif", Font.PLAIN, fontSize);
                    repaint();
                }
            });
        return button;
    }

    /**
     * Creates a button that changes the square size by delta (never below
     * MIN_SQUARE_SIZE) and resizes the grid panel to match.
     */
    private JButton gridSizeButton(String label, final int delta)
    {
        JButton button = new JButton(label);
        button.addActionListener(new ActionListener()
            {
                public void actionPerformed(ActionEvent event)
                {
                    sqsize = Math.max(MIN_SQUARE_SIZE, sqsize + delta);
                    gridDimension.setSize(sqsize*xDim, sqsize*yDim);
                    graphPanel.setPreferredSize(gridDimension);
                    graphPanel.revalidate();
                    repaint();
                }
            });
        return button;
    }

    /** The panel that paints the grid, the Q-values, the walls and the agent. */
    private class GridPanel extends JPanel
    {
        public GridPanel()
        {
            setPreferredSize(gridDimension);
        }

        /**
         * Paints the current state of the environment and controller, then
         * refreshes the four report labels from the environment's counters.
         */
        public void paintComponent(Graphics g)
        {
            super.paintComponent(g);

            g.setFont(myFont);

            // background and grid lines
            g.setColor(Color.black);
            g.fillRect(0,0,xDim*sqsize,yDim*sqsize);

            g.setColor(Color.white);
            for (int col = 1; col <= xDim; col++)
                g.drawLine(sqsize*col,0,sqsize*col,yDim*sqsize);
            for (int row = 1; row <= yDim; row++)
                g.drawLine(0,sqsize*row,xDim*sqsize,sqsize*row);

            // draw monsters
            if (environment.m21)
                colorSquare(g,2,1,Color.red);
            if (environment.m42)
                colorSquare(g,4,2,Color.red);
            if (environment.m03)
                colorSquare(g,0,3,Color.red);
            if (environment.m13)
                colorSquare(g,1,3,Color.red);
            if (environment.m33)
                colorSquare(g,3,3,Color.red);

            // draw prize: values 0..3 select one of the four corner squares;
            // any other value draws nothing
            if (environment.prize < 4)
                colorSquare(g,environment.prize%2*(xDim-1),environment.prize/2*(yDim-1),Color.cyan);

            // draw the repair square
            g.setColor(Color.magenta);
            g.fillRect(sqsize,0,sqsize,sqsize);

            // draw the arrows and Q-values of every cell
            for (int xval = 0; xval < xDim; xval++)
                for (int yval = 0; yval < yDim; yval++)
                    paintCell(g, xval, yval);

            // draw interior walls
            g.setColor(Color.white);
            g.fillRect(sqsize-3,0,7,sqsize*2);
            g.fillRect(2*sqsize-3,0,7,sqsize);

            // draw agent
            if (environment.damaged)
                g.setColor(Color.pink);
            else
                g.setColor(Color.yellow);
            int agentDiameter = sqsize-2*(sqsize/3);
            g.fillOval(environment.currX*sqsize+sqsize/3, environment.currY*sqsize+sqsize/3,
                       agentDiameter, agentDiameter);

            // show values for parameters, one per line, starting five squares from the left
            double[] valsToDisplay = controller.toDisplay();
            g.setColor(Color.black);
            int lineSpacing = fontSize+2;
            if (valsToDisplay != null)
                for (int i = 0; i < valsToDisplay.length; i++)
                    g.drawString(i+": "+valsToDisplay[i],5*sqsize+3,(1+i)*lineSpacing);

            // The report labels are refreshed here so that any repaint of
            // the grid also brings them up to date.
            reportStepsLabel.setText("Number of steps: "+environment.numberOfSteps);
            reportRewardsLabel.setText("Total reward received: "+environment.totalReward);
            reportMinLabel.setText("Min reward is "+environment.minReward+" at step "+environment.minStep);
            reportZeroLabel.setText("Zero crossing a step: "+environment.zeroCrossing);
        }

        /**
         * Paints one cell: a blue arrow for every action whose Q-value
         * equals the cell's maximum, then the four Q-values in white.
         */
        private void paintCell(Graphics g, int xval, int yval)
        {
            // Read each Q-value once; the exact == comparisons below are
            // against these same stored numbers.
            double[] q = new double[4];
            for (int action = 0; action < 4; action++)
                q[action] = controller.qvalue(xval,yval,action);

            double best = q[0];
            for (int action = 1; action < 4; action++)
                if (best < q[action])
                    best = q[action];

            int left = xval*sqsize;
            int top = yval*sqsize;
            int centreX = left+sqsize/2;
            int centreY = top+sqsize/2;

            g.setColor(Color.blue);
            if (best == q[0])
                fillVerticalArrow(g, centreX, centreY, top);
            if (best == q[1])
                fillHorizontalArrow(g, centreX, centreY, left+sqsize);
            if (best == q[2])
                fillVerticalArrow(g, centreX, centreY, top+sqsize);
            if (best == q[3])
                fillHorizontalArrow(g, centreX, centreY, left);

            // write the Q-values
            g.setColor(Color.white);
            g.drawString(qLabel(xval,yval,0,q[0]), left+sqsize/3, top+sqsize/3);
            g.drawString(qLabel(xval,yval,1,q[1]), left+2+sqsize/2, top+2*sqsize/3);
            g.drawString(qLabel(xval,yval,2,q[2]), left+sqsize/3, top+sqsize-1);
            g.drawString(qLabel(xval,yval,3,q[3]), left+2, top+2*sqsize/3);
        }

        /** Returns the formatted Q-value, followed by " (count)" when showCounts is set. */
        private String qLabel(int xval, int yval, int action, double qvalue)
        {
            String text = df.format(qvalue);
            if (showCounts)
                text += " ("+controller.getCounts(xval,yval,action)+")";
            return text;
        }

        /** Fills a triangle with a horizontal base through the cell centre and its tip at y = tipY. */
        private void fillVerticalArrow(Graphics g, int centreX, int centreY, int tipY)
        {
            int[] xs = {centreX-twid, centreX+twid, centreX};
            int[] ys = {centreY, centreY, tipY};
            g.fillPolygon(xs,ys,3);
        }

        /** Fills a triangle with a vertical base through the cell centre and its tip at x = tipX. */
        private void fillHorizontalArrow(Graphics g, int centreX, int centreY, int tipX)
        {
            int[] xs = {centreX, centreX, tipX};
            int[] ys = {centreY-twid, centreY+twid, centreY};
            g.fillPolygon(xs,ys,3);
        }
    }


    /**
     * Resets the controller, passing the number typed in the "Initial
     * Value" field, and then resets the environment.
     *
     * @throws NumberFormatException if the field does not hold a number
     */
    public void doreset()
    {
        controller.doreset(Double.parseDouble(initialValueField.getText()));
        environment.doreset();
    }

    /**
     * Copies alpha and discount from their text fields into the controller
     * and performs one step with the given action.
     *
     * @param action the action to perform; the direction buttons use
     *               0 = Up, 1 = Right, 2 = Down, 3 = Left
     * @throws NumberFormatException if either field does not hold a number
     */
    public void dostep(int action) {
        readLearningParameters();
        controller.dostep(action);
    }

    /**
     * Copies alpha and discount from their text fields into the controller,
     * runs count steps with the "Greedy Exploit" percentage converted to a
     * fraction, and requests a repaint.
     */
    private void doSteps(int count) {
        readLearningParameters();
        controller.doSteps(count,Double.parseDouble(greedyField.getText())/100.0);
        repaint();
    }

    /** Parses the alpha and discount text fields into the controller's fields. */
    private void readLearningParameters() {
        controller.alpha = Double.parseDouble(alphaField.getText());
        controller.discount = Double.parseDouble(discountField.getText());
    }

    /** Fills the grid square at column xval, row yval with the given colour. */
    private void colorSquare(Graphics g,int xval,int yval,Color col) {
        g.setColor(col);
        g.fillRect(xval*sqsize,yval*sqsize,sqsize,sqsize);
    }


}

