import numpy as np
import random
from tkinter import *
import time

# Main window: a 400x400 px canvas showing a 4x4 grid of 100 px cells.
tk = Tk()
tk.title('Q-Learning')
tk.wm_attributes('-topmost', 1)  # keep the demo window in front

canvas = Canvas(tk, width=400, height=400, bd=0, highlightthickness=0)
for i in range(4):
    # One vertical and one horizontal grid line at offset i*100 px.
    canvas.create_line(i * 100, 0, i * 100, 400)
    canvas.create_line(0, i * 100, 400, i * 100)

# Trap cells (khaki), listed by their top-left corner in px. With row-major
# state numbering (state = row*4 + column) these are states 2, 5, 6 and 9.
trap1, trap2, trap3, trap4 = [
    canvas.create_rectangle(x, y, x + 100, y + 100, fill='khaki')
    for x, y in ((200, 0), (100, 100), (200, 100), (100, 200))
]
canvas.pack()
tk.update()

# The agent sprite, initially drawn on the top-left cell (state 0).
agent = canvas.create_rectangle(0, 0, 100, 100, fill='orchid')

# Discount factor applied to the best next-state Q value in the update rule.
gamma = 0.8

# Reward table R[state, action].
# States are the 16 grid cells numbered row-major (state = row*4 + column).
# Actions are 0=up, 1=down, 2=left, 3=right.
# Entering a trap cell (states 2, 5, 6, 9) gives -10. Entering the goal
# (state 15) gives +10. Any other legal move gives +1. Moves that would leave
# the grid have reward 0; they are absent from `valid_action` and never taken.
R = np.array([
    [0, 1, 0, 1],
    [0, -10, 1, -10],
    [0, -10, 1, 1],
    [0, 1, -10, 0],
    [1, 1, 0, -10],
    [1, -10, 1, -10],
    [-10, 1, -10, 1],
    [1, 1, -10, 0],
    [1, 1, 0, -10],
    [-10, 1, 1, 1],
    [-10, 1, -10, 1],
    [1, 10, 1, 0],
    [1, 0, 0, 1],
    [-10, 0, 1, 1],
    [1, 0, 1, 10],
    [1, 0, 1, 0],
])

# Learned action values Q[state, action]. Starts at zero and is updated in
# place during training.
Q = np.zeros((16, 4))

# Legal actions per state (moves that stay on the grid).
# The rows have different lengths, so this must stay a plain list of lists:
# np.array() on ragged rows raises ValueError in NumPy >= 1.24.
valid_action = [
    [1, 3],
    [1, 2, 3],
    [1, 2, 3],
    [1, 2],
    [0, 1, 3],
    [0, 1, 2, 3],
    [0, 1, 2, 3],
    [0, 1, 2],
    [0, 1, 3],
    [0, 1, 2, 3],
    [0, 1, 2, 3],
    [0, 1, 2],
    [0, 3],
    [0, 2, 3],
    [0, 2, 3],
    [0, 2],
]

# Deterministic next state: transition_matrix[state][action].
# -1 marks a move that would leave the grid.
transition_matrix = np.array([
    [-1, 4, -1, 1],
    [-1, 5, 0, 2],
    [-1, 6, 1, 3],
    [-1, 7, 2, -1],
    [0, 8, -1, 5],
    [1, 9, 4, 6],
    [2, 10, 5, 7],
    [3, 11, 6, -1],
    [4, 12, -1, 9],
    [5, 13, 8, 10],
    [6, 14, 9, 11],
    [7, 15, 10, -1],
    [8, -1, -1, 13],
    [9, -1, 12, 14],
    [10, -1, 13, 15],
    [11, -1, 14, -1],
])



def start(s):
    """Place the agent sprite on the cell for state ``s`` and redraw.

    States are numbered row-major on the 4x4 grid, so state ``s`` is at
    row ``s // 4`` and column ``s % 4``. Each cell is 100x100 px.
    """
    row, column = divmod(s, 4)
    left, top = column * 100, row * 100
    canvas.coords(agent, left, top, left + 100, top + 100)
    tk.update()
    time.sleep(0.05)  # brief pause so the placement is visible
def moves(a):
    """Shift the agent sprite one 100 px cell in direction ``a`` and redraw.

    Action codes: 0 = up, 1 = down, 2 = left; any other value moves right.
    The canvas y axis grows downward, so "up" is a negative y offset.
    """
    dx, dy = {0: (0, -100), 1: (0, 100), 2: (-100, 0)}.get(a, (100, 0))
    canvas.move(agent, dx, dy)
    tk.update()
    time.sleep(0.01)
    
def QLearning():
    """Run one training episode, animating each move.

    The episode starts from a uniformly random state in 0..15 and ends on
    reaching state 15 (a start of 15 gives an empty episode).

    Actions are drawn uniformly from the state's valid moves, so training is
    pure exploration. After every move the shared ``Q`` table is overwritten
    in place with the deterministic update
    ``Q[s, a] = R[s, a] + gamma * max(Q[s'])``.
    """
    state = random.randint(0, 15)
    start(state)
    while state != 15:
        action = random.choice(valid_action[state])
        next_state = transition_matrix[state][action]
        moves(action)
        best_next = Q[next_state].max()
        Q[state, action] = R[state, action] + gamma * best_next
        state = next_state
# Training: 100 independent episodes, each updating the shared Q table.
for i in range(100):
    QLearning()

# Announce the switch to the greedy test run, then pause so it can be read.
label = Label(
    tk,
    text='Training over!!!,start test.',
    bg='green',
    compound='center',
)
label.pack()
tk.update()
time.sleep(3)
def test(s):
    """Replay the learned greedy policy from state ``s`` until reaching state 15.

    Each step takes the valid action with the highest Q value, animates it,
    and prints the visited states as a path (``s-> s1-> s2 ...``).

    Only actions in ``valid_action[s]`` are considered. Q entries for off-grid
    moves are never updated and stay 0, so a plain ``Q[s].argmax()`` could
    pick one of them on a tie or when every valid value is negative. That
    move's transition is -1, which would drive the agent off the canvas.
    """
    print(s, end="", flush=True)  # flush so the path appears as it is walked
    start(s)
    while s != 15:
        q_row = Q[s]
        # max() returns the first maximal candidate, and valid_action rows are
        # ascending, so ties break exactly like argmax() over valid actions.
        a = max(valid_action[s], key=q_row.__getitem__)
        s = transition_matrix[s][a]
        moves(a)
        time.sleep(1)
        print("-> %d" % s, end="", flush=True)
    print()  # terminate the path line
# Demo: follow the learned greedy policy from state 5, then keep the window
# open until the user closes it.
test(s=5)
tk.mainloop()