import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
?
# Initial state of the maze: a 3x3 grid of cells S0..S8 (S0 top-left, S8 bottom-right).
fig = plt.figure(figsize=(5, 5))
ax = plt.gca()

# Red segments are the maze's internal walls.
plt.plot([1, 1], [0, 1], color='red', linewidth=2)
plt.plot([1, 2], [2, 2], color='red', linewidth=2)
plt.plot([2, 2], [2, 1], color='red', linewidth=2)
plt.plot([2, 3], [1, 1], color='red', linewidth=2)

# Label every cell: state k sits in row k // 3 (top row first) and column k % 3,
# so its centre is at x = col + 0.5, y = 2.5 - row.
for state in range(9):
    row, col = divmod(state, 3)
    plt.text(col + 0.5, 2.5 - row, 'S' + str(state), size=14, ha='center')
plt.text(0.5, 2.3, 'START', ha='center')
plt.text(2.5, 0.3, 'GOAL', ha='center')

ax.set_xlim(0, 3)
ax.set_ylim(0, 3)
# Hide all ticks and tick labels. These must be real booleans: the old
# 'on'/'off' string spellings are no longer converted to True/False by
# current Matplotlib, and the non-empty string 'off' is truthy.
plt.tick_params(axis='both', which='both', bottom=False, top=False,
                labelbottom=False, right=False, left=False, labelleft=False)

# Green marker for the agent's current position; it starts on S0.
line, = ax.plot([0.5], [2.5], marker="o", color='g', markersize=60)
# Initial policy parameters theta_0: one row per non-goal state S0..S7 and one
# column per action, in the order [up, right, down, left] (the same order as
# the `direction` list used by get_action / get_next_s).
# 1 marks a move that is allowed from that cell; np.nan marks a move blocked by
# a wall or by the edge of the maze. The goal S8 has no row: an episode ends
# as soon as it is reached.
theta_0 = np.array([[np.nan, 1, 1, np.nan],  # s0: right, down
                    [np.nan, 1, np.nan, 1],  # s1: right, left
                    [np.nan, np.nan, 1, 1],  # s2: down, left
                    [1, 1, 1, np.nan],  # s3: up, right, down
                    [np.nan, np.nan, 1, 1],  # s4: down, left
                    [1, np.nan, np.nan, np.nan],  # s5: up
                    [1, np.nan, np.nan, np.nan],  # s6: up
                    [1, 1, np.nan, np.nan],  # s7: up, right
                    ])
# Initial action-value function Q(s, a).
a, b = theta_0.shape  # a = number of states (rows), b = number of actions (columns)
# Uniform random values in [0, 1) for allowed moves; multiplying by theta_0
# keeps NaN in every cell where theta_0 marks a wall.
Q = theta_0 * np.random.rand(a, b)
def simple_convert_into_pi_from_theta(theta):
    """Turn a parameter table into a row-stochastic policy.

    Each row of ``theta`` is divided by its NaN-ignoring sum, so the allowed
    entries of a row become probabilities. NaN entries (walls) come out as 0.

    Args:
        theta: 2-D array, one row per state, NaN for disallowed actions.

    Returns:
        A float array with the same shape as ``theta``.
    """
    n_rows, n_cols = theta.shape
    probs = np.zeros((n_rows, n_cols))
    for row_idx, row in enumerate(theta):
        probs[row_idx] = row / np.nansum(row)
    # NaN (wall) -> 0 so the rows can be fed to np.random.choice as p=...
    return np.nan_to_num(probs)
?
# Exploration policy: theta_0 normalised row by row into move probabilities.
pi_0 = simple_convert_into_pi_from_theta(theta=theta_0)
def get_action(s, Q, epsilon, pi_0):
    """Choose the action index for state ``s`` epsilon-greedily.

    With probability ``epsilon`` the move is sampled from the distribution in
    row ``pi_0[s, :]`` (exploration). Otherwise the column of ``Q[s, :]`` with
    the largest non-NaN value is used (exploitation).

    Returns:
        int: 0 for up, 1 for right, 2 for down, 3 for left.
    """
    direction = ["up", "right", "down", "left"]
    explore = np.random.rand() < epsilon
    if not explore:
        # Greedy: nanargmax skips the NaN entries that mark walls.
        return int(np.nanargmax(Q[s, :]))
    # Explore: sample a direction name from the policy row, then map it back
    # to its position in `direction`.
    picked = np.random.choice(direction, p=pi_0[s, :])
    return direction.index(picked)
def get_next_s(s, a, Q=None, epsilon=None, pi_0=None):
    """Return the state reached from ``s`` by taking action ``a``.

    States are numbered row by row on the 3x3 maze (S0..S8), so moving up or
    down changes the index by 3 and moving left or right changes it by 1.

    Args:
        s: current state index.
        a: action index: 0 = up, 1 = right, 2 = down, 3 = left.
        Q, epsilon, pi_0: unused. They are still accepted so existing call
            sites that pass them keep working. They are optional, so the
            transition can be computed from ``s`` and ``a`` alone (omitting
            them previously raised ``TypeError``).

    Returns:
        The next state index. Walls are not checked here; callers must only
        pass actions that are allowed in ``s``.
    """
    # State-index offset per action, in the order up, right, down, left.
    offsets = (-3, 1, 3, -1)
    return s + offsets[a]
# Update the action-value function with the SARSA rule.
def sarsa(s, a, r, next_s, next_a, Q, eta, gamma):
    """Apply one SARSA update to ``Q[s, a]`` in place and return ``Q``.

    target = r                              if next_s is the goal (state 8)
    target = r + gamma * Q[next_s, next_a]  otherwise
    Q[s, a] <- Q[s, a] + eta * (target - Q[s, a])

    ``next_a`` is only read when ``next_s`` is not the goal, so it may be NaN
    on the final step.
    """
    target = r if next_s == 8 else r + gamma * Q[next_s, next_a]
    Q[s, a] = Q[s, a] + eta * (target - Q[s, a])
    return Q
# Run one SARSA episode; return the visited state/action history and the updated Q.
def goal_maze_ret_s_a_Q(Q, epsilon, eta, gamma, pi):
    """Run one SARSA episode from the start state S0 until the goal S8.

    Args:
        Q: action-value table; ``sarsa`` updates it in place on every step.
        epsilon: exploration probability passed to ``get_action``.
        eta: learning rate.
        gamma: discount factor.
        pi: exploration policy rows passed to ``get_action``.

    Returns:
        ``[s_a_history, Q]``. ``s_a_history`` is the list of ``[state, action]``
        pairs in visiting order. Its last entry is ``[8, nan]`` because no
        action is taken at the goal.
    """
    s = 0  # start state (S0)
    next_a = get_action(s, Q, epsilon, pi)
    s_a_history = [[0, np.nan]]  # states visited and the action taken in each

    while True:
        a = next_a
        s_a_history[-1][1] = a
        next_s = get_next_s(s, a, Q, epsilon, pi)
        s_a_history.append([next_s, np.nan])  # action filled in next iteration

        if next_s == 8:
            r = 1  # reward for reaching the goal
            next_a = np.nan
        else:
            r = 0
            # SARSA picks the follow-up action in the state just entered.
            # Choosing it from `s` (the state just left) yields moves that may
            # be invalid in next_s: through walls, across rows, or past the
            # last row of Q.
            next_a = get_action(next_s, Q, epsilon, pi)

        Q = sarsa(s, a, r, next_s, next_a, Q, eta, gamma)

        if next_s == 8:
            break
        s = next_s

    return [s_a_history, Q]
# SARSA hyper-parameters.
eta = 0.1  # learning rate
gamma = 0.9  # discount factor
epsilon = 0.5  # exploration rate; halved at the start of every episode
max_episodes = 100  # same bound as the former "episode > 100" break

# State values V(s) = max_a Q(s, a), used to report how much Q changed per episode.
v = np.nanmax(Q, axis=1)

for episode in range(1, max_episodes + 1):
    print('當前回合:' + str(episode))
    epsilon = epsilon / 2
    [s_a_history, Q] = goal_maze_ret_s_a_Q(Q, epsilon, eta, gamma, pi_0)
    new_v = np.nanmax(Q, axis=1)
    print('狀態值變化量為:' + str(np.sum(np.abs(new_v - v))))
    v = new_v
    print(s_a_history)
    # The history includes the start state, so the number of moves is len - 1.
    print("規劃此路徑共需要" + str(len(s_a_history) - 1) + "步")
當前回合:1
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-23-360980ca88f2> in <module>
9 print('當前回合:'+str(episode))
10 epsilon=epsilon/2
---> 11 [s_a_history,Q]=goal_maze_ret_s_a_Q(Q,epsilon,eta,gamma,pi_0)
12 new_v=np.nanmax(Q,axis=1)
13 print('狀態值變化量為:'+str(np.sum(np.abs(new_v-v))))
<ipython-input-8-de0a47c0aa27> in goal_maze_ret_s_a_Q(Q, epsilon, eta, gamma, pi)
6 a=next_a
7 s_a_history[-1][1]=a
----> 8 s=get_next_s(s,a,Q,epsilon)
9 s_a_history.append([next_s,np.nan]) #記錄每次行動后的狀態
10 if next_s == 8: #到達終點位置
TypeError: get_next_s() missing 1 required positional argument: 'pi_0'
uj5u.com熱心網友回復:
請大神們指點,感謝您!
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/234036.html
上一篇:求助大佬!
