import java.awt.*;
import java.awt.event.*;
import java.awt.geom.*;
//import java.applet.*;
import java.text.*;
import java.util.*;
import javax.swing.*;

/**
 * This applet demonstrates a simple grid-world game.  It is complex
 * enough so that simple, rule-based strategies are difficult to
 * write, but it is small enough so that state-based MDP and RL
 * algorithms can work. There are lots of ways it can be extended.
 * It isn't designed to be general or reusable.
 * <P>
 * Copyright (C) 2008  <A HREF="http://www.cs.ubc.ca/spider/poole/">David Poole</A>.
 * <p>
 * This program gives the GUI. The GUI is in <A
 * HREF="SGameGUI.java">SGameGUI.java</A>. The environment code is at
 * <A HREF="SGameEnv.java">SGameEnv.java</A>. The controller is at <A
 * HREF="SGameController.java">SGameController.java</A>.
 * <p>
 This program is free software; you can redistribute it and/or
 modify it under the terms of the GNU General Public License
 as published by the Free Software Foundation; either version 2
 of the License, or (at your option) any later version.
 <p>
 This program is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 GNU General Public License for more details.
 <p>
 You should have received a copy of the GNU General Public License
 along with this program; if not, write to the Free Software
 Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.


 * @author David Poole  poole@cs.ubc.ca
 * @version 0.41 2007-09-09 */



public class MASLearningGUI extends JApplet
{
    /** Pixels added or removed by each press of the plot-size buttons. */
    private static final int GRAPH_SIZE_STEP = 50;
    /** Smallest plot side length the plot-size "-" button will shrink to. */
    private static final int MIN_GRAPH_SIZE = GRAPH_SIZE_STEP;
    /** Smallest font size the font-size "-" button will shrink to. */
    private static final int MIN_FONT_SIZE = 1;
    /** Room (beyond fontSize) kept below the horizontal axis for its tick and title labels. */
    private static final int LABEL_MARGIN = 10;

    WolfLearningAgent rowAgent;
    WolfLearningAgent columnAgent;
    SampleSoccerGame game = new SampleSoccerGame();

    /** Run whose points plotProbs is currently appending to. */
    RunData currentRun;
    /** Every run plotted so far (including currentRun); drawn by GridPanel. */
    Vector<RunData> runHistory;

    // Payoffs for the row player: rNMField holds r[N][M], where N is the row
    // player's action and M is the column player's action (see doSteps).
    JTextField r00Field;
    JTextField r01Field;
    JTextField r10Field;
    JTextField r11Field;
    // Payoffs for the column player, indexed the same way as the row player's.
    JTextField c00Field;
    JTextField c01Field;
    JTextField c10Field;
    JTextField c11Field;

    JTextField alphaField;
    JTextField deltaField;
    /** Number of game rounds played per press of "Step". */
    JTextField stepField;
    /** Editors for rowAgent.parameter, index-aligned with that array. */
    JTextField[] rowParameterField;
    /** Editors for columnAgent.parameter, index-aligned with that array. */
    JTextField[] columnParameterField;

    JLabel reportStepsLabel;
    JLabel reportRowProbLabel;
    JLabel reportColumnProbLabel;
    JLabel reportRowValsLabel;
    JLabel reportColValsLabel;
    JLabel rowParameterLabel;
    JLabel columnParameterLabel;
    JLabel reportZeroLabel;

    int fieldWidth = 6;  // length (in columns) of the parameter text boxes

    /** Side length of the square plot area, in pixels. */
    int graphSize = 500;
    int topBorder = 20;
    int leftBorder = 70;
    int fontSize = 14;
    Font myFont = new Font("SansSerif", Font.PLAIN, fontSize);
    /** Diameter of each plotted point, in pixels. */
    int dotSize = 3;

    DecimalFormat df = new DecimalFormat("0.##");
    // Relies on graphSize, topBorder, leftBorder and fontSize being declared
    // (and therefore initialised) above this line.
    Dimension gridDimension = plotPanelSize();

    GridPanel graphPanel;

    /** When true, every plotted point is also printed to standard output. */
    boolean tracing = false;

    /**
     * Builds the user interface and starts the first run.
     *
     * The layout has payoff/parameter editors and agent reports on the right,
     * display controls along the bottom, and the scrollable probability plot in
     * the centre. This method also creates both agents.
     */
    @Override
    public void init()
    {
        rowAgent = new WolfLearningAgent(2);
        columnAgent = new WolfLearningAgent(2);

        rowParameterField = new JTextField[rowAgent.parameter.length];
        columnParameterField = new JTextField[columnAgent.parameter.length];
        runHistory = new Vector<RunData>();
        doreset();

        graphPanel = new GridPanel();

        JPanel pan = new JPanel();
        pan.setLayout(new BoxLayout(pan, BoxLayout.Y_AXIS));
        addPayoffEditors(pan);
        addRunControls(pan);
        rowParameterLabel = addParameterEditors(pan, "Row Agent Parameters:",
                                                rowAgent.parameter, rowParameterField);
        columnParameterLabel = addParameterEditors(pan, "Column Agent Parameters:",
                                                   columnAgent.parameter, columnParameterField);
        addReportLabels(pan);

        getContentPane().add(pan, "East");
        getContentPane().add(createDisplayControls(), "South");
        getContentPane().add(new JScrollPane(graphPanel), "Center");
        repaint();
    }

    /** Adds the row and column players' payoff-matrix editors to pan. */
    private void addPayoffEditors(JPanel pan)
    {
        r00Field = new JTextField("" + game.r00, 3);
        r01Field = new JTextField("" + game.r01, 3);
        r10Field = new JTextField("" + game.r10, 3);
        r11Field = new JTextField("" + game.r11, 3);
        reportZeroLabel = addPayoffMatrix(pan, "Row Payoff   Action 0    Action 1 ",
                                          r00Field, r01Field, r10Field, r11Field);
        pan.add(Box.createVerticalGlue());

        c00Field = new JTextField("" + game.c00, 3);
        c01Field = new JTextField("" + game.c01, 3);
        c10Field = new JTextField("" + game.c10, 3);
        c11Field = new JTextField("" + game.c11, 3);
        // reportZeroLabel ends up referring to the column-payoff title, as it always has.
        reportZeroLabel = addPayoffMatrix(pan, "Column Payoff   Action 0    Action 1 ",
                                          c00Field, c01Field, c10Field, c11Field);
        pan.add(Box.createVerticalGlue());
    }

    /**
     * Adds one 2x2 payoff editor to pan.
     *
     * The editor is a title line followed by one line per row-player action,
     * Action 1 first. Within a line, the first field is for column action 0
     * and the second for column action 1.
     *
     * @return the title label
     */
    private static JLabel addPayoffMatrix(JPanel pan, String title,
                                          JTextField f00, JTextField f01,
                                          JTextField f10, JTextField f11)
    {
        JLabel titleLabel = new JLabel(title);
        pan.add(lineOf(titleLabel));
        pan.add(lineOf(new JLabel("Action 1"), f10, f11));
        pan.add(lineOf(new JLabel("Action 0"), f00, f01));
        return titleLabel;
    }

    /** Adds the Step (with its count field), Random Restart and Clear controls to pan. */
    private void addRunControls(JPanel pan)
    {
        JPanel stepPanel = new JPanel();
        JButton step = addButton(stepPanel, "Step", new ActionListener()
            {
                @Override
                public void actionPerformed(ActionEvent event)
                {
                    try {
                        doSteps(Integer.parseInt(stepField.getText().trim()));
                    } catch (NumberFormatException e) {
                        // Report bad input instead of letting the exception escape on the EDT.
                        JOptionPane.showMessageDialog(MASLearningGUI.this,
                                                      "Could not read a number: " + e.getMessage(),
                                                      "Invalid input",
                                                      JOptionPane.ERROR_MESSAGE);
                    }
                    repaint();
                }
            });
        step.setFont(new Font("SansSerif", Font.BOLD, 20));
        stepField = new JTextField("10000", 7);
        stepPanel.add(stepField);
        pan.add(stepPanel);

        JPanel resetPanel = new JPanel();
        addButton(resetPanel, "Random Restart", new ActionListener()
            {
                @Override
                public void actionPerformed(ActionEvent event)
                {
                    doreset();
                    updateAgentInfo();
                    repaint();
                }
            });
        pan.add(resetPanel);

        JPanel clearPanel = new JPanel();
        addButton(clearPanel, "Clear", new ActionListener()
            {
                @Override
                public void actionPerformed(ActionEvent event)
                {
                    // Forget all previous runs, then start a fresh one.
                    runHistory = new Vector<RunData>();
                    doreset();
                    updateAgentInfo();
                    repaint();
                }
            });
        pan.add(clearPanel);
    }

    /**
     * Adds a titled list of editors, one per parameter, to pan.
     *
     * Each new text field is stored into fields at the same index as its parameter.
     *
     * @return the title label
     */
    private JLabel addParameterEditors(JPanel pan, String title,
                                       Parameter[] params, JTextField[] fields)
    {
        JLabel titleLabel = new JLabel(title);
        pan.add(lineOf(titleLabel));
        for (int i = 0; i < params.length; i++) {
            fields[i] = new JTextField("" + params[i].value, fieldWidth);
            pan.add(lineOf(new JLabel(params[i].parameterName), fields[i]));
        }
        return titleLabel;
    }

    /** Creates the four agent-report labels, fills them in, and adds them to pan. */
    private void addReportLabels(JPanel pan)
    {
        reportRowProbLabel = new JLabel();
        reportColumnProbLabel = new JLabel();
        reportRowValsLabel = new JLabel();
        reportColValsLabel = new JLabel();
        // A single source of truth for the label texts, so the initial text
        // cannot drift from what later updates display.
        updateAgentInfo();
        pan.add(lineOf(reportRowProbLabel));
        pan.add(lineOf(reportColumnProbLabel));
        pan.add(lineOf(reportRowValsLabel));
        pan.add(lineOf(reportColValsLabel));
    }

    /** Builds the bottom bar: plot size, font size, dot size and the tracing switch. */
    private JPanel createDisplayControls()
    {
        JPanel panel = new JPanel();

        panel.add(new JLabel("Plot Size"));
        addButton(panel, "-", new ActionListener()
            {
                @Override
                public void actionPerformed(ActionEvent event)
                {
                    graphSize = Math.max(MIN_GRAPH_SIZE, graphSize - GRAPH_SIZE_STEP);
                    resizeGraphPanel();
                }
            });
        addButton(panel, "+", new ActionListener()
            {
                @Override
                public void actionPerformed(ActionEvent event)
                {
                    graphSize += GRAPH_SIZE_STEP;
                    resizeGraphPanel();
                }
            });

        panel.add(new JLabel("Font Size"));
        addButton(panel, "-", new ActionListener()
            {
                @Override
                public void actionPerformed(ActionEvent event)
                {
                    fontSize = Math.max(MIN_FONT_SIZE, fontSize - 1);
                    myFont = new Font("SansSerif", Font.PLAIN, fontSize);
                    resizeGraphPanel();  // label margin depends on fontSize
                }
            });
        addButton(panel, "+", new ActionListener()
            {
                @Override
                public void actionPerformed(ActionEvent event)
                {
                    fontSize++;
                    myFont = new Font("SansSerif", Font.PLAIN, fontSize);
                    resizeGraphPanel();  // label margin depends on fontSize
                }
            });

        panel.add(new JLabel("Dot Size: "));
        addButton(panel, "-", new ActionListener()
            {
                @Override
                public void actionPerformed(ActionEvent event)
                {
                    dotSize = Math.max(1, dotSize - 1);
                    repaint();
                }
            });
        addButton(panel, "+", new ActionListener()
            {
                @Override
                public void actionPerformed(ActionEvent event)
                {
                    dotSize += 1;
                    repaint();
                }
            });

        final JCheckBox tracingCheckBox = new JCheckBox("Trace on console");
        tracingCheckBox.addActionListener(new ActionListener()
            {
                @Override
                public void actionPerformed(ActionEvent event)
                {
                    tracing = tracingCheckBox.isSelected();
                }
            });
        panel.add(tracingCheckBox);

        return panel;
    }

    /** Creates a button labelled text, wires it to listener, adds it to panel and returns it. */
    private static JButton addButton(JPanel panel, String text, ActionListener listener)
    {
        JButton button = new JButton(text);
        panel.add(button);
        button.addActionListener(listener);
        return button;
    }

    /** Wraps the given components in a new JPanel (default FlowLayout) so they share one line. */
    private static JPanel lineOf(JComponent... components)
    {
        JPanel line = new JPanel();
        for (JComponent component : components) {
            line.add(component);
        }
        return line;
    }

    /**
     * Computes the preferred size of the plot panel.
     *
     * The size leaves room below and to the right of the axes for the tick and
     * title labels that GridPanel draws there.
     */
    private Dimension plotPanelSize()
    {
        return new Dimension(leftBorder + graphSize + fontSize,
                             topBorder + graphSize + fontSize + LABEL_MARGIN);
    }

    /** Applies a changed graphSize/fontSize to the plot panel so its scroll pane re-lays out. */
    private void resizeGraphPanel()
    {
        gridDimension = plotPanelSize();
        graphPanel.setPreferredSize(gridDimension);
        graphPanel.revalidate();
        repaint();
    }

    /**
     * Draws the axes and every recorded run.
     *
     * The horizontal axis is the column agent's probability of Action 1 and
     * the vertical axis is the row agent's, matching the argument order of
     * plotProbs(columnAgent.prob[1], rowAgent.prob[1]).
     */
    private class GridPanel extends JPanel
    {
        public GridPanel()
        {
            setPreferredSize(gridDimension);
        }

        @Override
        public void paintComponent(Graphics g)
        {
            super.paintComponent(g);
            g.setFont(myFont);
            drawAxes(g);
            drawRuns(g);
        }

        /** Draws both axes, their 0/1 ticks, and their titles. */
        private void drawAxes(Graphics g)
        {
            int originY = topBorder + graphSize;           // pixel row for probability 0
            int xLabelBaseline = originY + fontSize + 5;

            g.setColor(Color.black);
            g.fillRect(leftBorder, topBorder, 1, graphSize);  // vertical axis
            g.fillRect(leftBorder, originY, graphSize, 1);    // horizontal axis

            g.drawString("0", leftBorder - 3, xLabelBaseline);
            g.drawString("1", leftBorder + graphSize - 3, xLabelBaseline);
            g.drawString("Prob (column does Action 1)", leftBorder + graphSize / 3, xLabelBaseline);

            int yTickX = leftBorder - fontSize / 2 - 5;
            g.drawString("0", yTickX, originY + fontSize / 2);
            g.drawString("1", yTickX, topBorder + fontSize / 2);

            // Draw the vertical-axis title rotated by -90 degrees, then restore the transform.
            Graphics2D g2 = (Graphics2D) g;
            AffineTransform saveAT = g2.getTransform();
            int yTitleX = leftBorder - fontSize - 5;
            int yTitleY = topBorder + graphSize * 2 / 3;
            g2.rotate(-Math.PI / 2, yTitleX, yTitleY);
            g2.drawString("Prob (row does Action 1)", yTitleX, yTitleY);
            g2.setTransform(saveAT);
        }

        /** Plots every point of every run in that run's colour. */
        private void drawRuns(Graphics g)
        {
            // fillOval takes the bounding box's top-left corner; shift by half a dot
            // so each dot is centred on its data point.
            int half = dotSize / 2;
            for (RunData run : runHistory) {
                g.setColor(run.runColor);
                for (PointData point : run.points) {
                    int px = leftBorder + (int) (point.xval * graphSize);
                    int py = topBorder + graphSize - (int) (point.yval * graphSize);
                    g.fillOval(px - half, py - half, dotSize, dotSize);
                }
            }
        }
    }

    /** One trajectory of plotted points, drawn in a randomly chosen colour. */
    private static class RunData {
        Color runColor = new Color((float) Math.random(),
                                   (float) Math.random(),
                                   (float) Math.random());
        Vector<PointData> points = new Vector<PointData>();
    }

    /** A single plotted point: xval on the horizontal axis, yval on the vertical axis. */
    private static class PointData {
        double xval;
        double yval;

        PointData(double x, double y) {
            xval = x;
            yval = y;
        }
    }

    /**
     * Starts a new run.
     *
     * Appends a fresh RunData to runHistory, randomizes both agents'
     * strategies, and plots the starting point.
     */
    public void doreset() {
        currentRun = new RunData();
        runHistory.add(currentRun);
        columnAgent.randomizeStrategy();
        rowAgent.randomizeStrategy();
        plotProbs(columnAgent.prob[1], rowAgent.prob[1]);
    }

    /**
     * Plays count rounds of the game, plotting both agents' probabilities after each round.
     *
     * All payoff and parameter fields are parsed before any agent state is
     * changed, so a malformed entry leaves the agents untouched.
     *
     * @throws NumberFormatException if a payoff or parameter field is not a number
     */
    private void doSteps(int count) {
        double[][] r = readPayoffs(r00Field, r01Field, r10Field, r11Field);
        double[][] c = readPayoffs(c00Field, c01Field, c10Field, c11Field);
        double[] rowValues = parseFields(rowParameterField);
        double[] columnValues = parseFields(columnParameterField);

        applyParameters(rowValues, rowAgent.parameter);
        applyParameters(columnValues, columnAgent.parameter);

        for (int i = 0; i < count; i++) {
            int a1 = rowAgent.getAction();
            int a2 = columnAgent.getAction();
            rowAgent.tellReward(r[a1][a2]);
            columnAgent.tellReward(c[a1][a2]);
            plotProbs(columnAgent.prob[1], rowAgent.prob[1]);
        }
        updateAgentInfo();
    }

    /** Reads a 2x2 payoff matrix indexed [rowAction][columnAction] from four fields. */
    private static double[][] readPayoffs(JTextField f00, JTextField f01,
                                          JTextField f10, JTextField f11) {
        return new double[][] {
            { Double.parseDouble(f00.getText()), Double.parseDouble(f01.getText()) },
            { Double.parseDouble(f10.getText()), Double.parseDouble(f11.getText()) }
        };
    }

    /** Parses every field as a double, in order; throws NumberFormatException on bad input. */
    private static double[] parseFields(JTextField[] fields) {
        double[] values = new double[fields.length];
        for (int i = 0; i < fields.length; i++) {
            values[i] = Double.parseDouble(fields[i].getText());
        }
        return values;
    }

    /** Copies already-parsed values into the matching parameters (index-aligned). */
    private static void applyParameters(double[] values, Parameter[] params) {
        for (int i = 0; i < params.length; i++) {
            params[i].value = values[i];
        }
    }

    /** Records the point (x, y) in the current run and echoes it to stdout when tracing. */
    private void plotProbs(double x, double y) {
        currentRun.points.add(new PointData(x, y));
        if (tracing) {
            System.out.println(x + " " + y);
        }
    }

    /** Refreshes the report labels from each agent's current probabilities and values. */
    private void updateAgentInfo() {
        reportRowProbLabel.setText("Prob (Row does Action 1)=" + df.format(rowAgent.prob[1]));
        reportColumnProbLabel.setText("Prob(Column does Action 1)=" + df.format(columnAgent.prob[1]));
        reportRowValsLabel.setText("Row Vals=" + df.format(rowAgent.value[0]) + ", " + df.format(rowAgent.value[1]));
        reportColValsLabel.setText("Col Vals=" + df.format(columnAgent.value[0]) + ", " + df.format(columnAgent.value[1]));
    }
}

