- sarsa: state-action-reward-state-action
- sarsa-lambda
SARSA is similar to Q-learning, but at $s_2$ it does not take $\max Q$; instead it picks an action directly from the policy. This is why SARSA is an on-policy algorithm, whereas Q-learning is off-policy.
The idea behind the SARSA algorithm:
By comparison, Q-learning is the greedier algorithm, because its update uses $\max Q$.
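To make the difference concrete, here is a minimal sketch of the two update targets. The function names, the list-based Q-table, and the numbers are illustrative assumptions, not part of the original code.

```python
def sarsa_target(q, r, s_next, a_next, gamma):
    """On-policy target: uses the action the policy actually chose in s_next."""
    return r + gamma * q[s_next][a_next]


def q_learning_target(q, r, s_next, gamma):
    """Off-policy target: uses the greedy (maximum) value in s_next."""
    return r + gamma * max(q[s_next])


# Hypothetical Q-table: 2 states x 2 actions
q = [[0.0, 0.0],
     [1.0, 5.0]]
r, gamma = 1.0, 0.9
s_next, a_next = 1, 0  # suppose the policy explored and picked action 0

print(sarsa_target(q, r, s_next, a_next, gamma))  # about 1.9 = 1 + 0.9 * 1.0
print(q_learning_target(q, r, s_next, gamma))     # about 5.5 = 1 + 0.9 * 5.0
```

When the policy happens to pick the greedy action, the two targets coincide; they differ only on exploratory steps.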
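The code:

    for episode in range(episodes):
        # reset the environment and get the initial observation
        observation = env.reset()
        # choose the first action from the current policy
        action = RL.choose_action(observation)
        while True:
            env.render()
            # take the action; observe the next state, the reward and the done flag
            observation_next, reward, done = env.step(action)
            # choose the next action from the same policy (this makes it on-policy)
            action_next = RL.choose_action(observation_next)
            # learn from the transition (s, a, r, s', a')
            RL.learn(observation, action, reward, observation_next, action_next)
            # the next state and action become the current ones
            observation = observation_next
            action = action_next
            if done:
                break
    print('Game over')
    env.destroy()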
Likewise, the agent skeleton looks roughly like this:

    class SarsaTable(RL):
        def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
            # initialise the parent RL class (inheritance)
            super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

        def learn(self, s, a, r, s_, a_):
            self.check_state_exist(s_)
            q_predict = self.q_table.loc[s, a]
            if s_ != 'terminal':
                # q_target uses the already chosen a_, not the maximum of Q(s_)
                q_target = r + self.gamma * self.q_table.loc[s_, a_]
            else:
                q_target = r  # no future value after the terminal state
            self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update q_table
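The loop and the `learn` step above rely on an environment and an `RL` base class that are not shown here. As a self-contained sketch of the same idea, the following runs tabular SARSA on a hypothetical 1-D corridor. The environment, the `run_sarsa` name, and the hyperparameters are assumptions made for illustration. Note that `epsilon` here is the exploration probability, which is not necessarily the convention behind `e_greedy` in the class above.

```python
import random


def run_sarsa(n_states=5, episodes=200, alpha=0.1, gamma=0.9, epsilon=0.1, seed=0):
    """Tabular SARSA on a hypothetical corridor.

    States are 0..n_states-1; action 0 moves left, action 1 moves right.
    Reaching the rightmost state ends the episode with reward 1.
    """
    rng = random.Random(seed)
    q = [[0.0, 0.0] for _ in range(n_states)]

    def choose_action(s):
        # epsilon-greedy with random tie-breaking
        if rng.random() < epsilon:
            return rng.randrange(2)
        best = max(q[s])
        return rng.choice([a for a in (0, 1) if q[s][a] == best])

    for _ in range(episodes):
        s = 0
        a = choose_action(s)
        done = False
        while not done:
            s_next = max(s - 1, 0) if a == 0 else s + 1
            done = s_next == n_states - 1
            r = 1.0 if done else 0.0
            # the next action comes from the same policy (on-policy)
            a_next = choose_action(s_next)
            target = r if done else r + gamma * q[s_next][a_next]
            q[s][a] += alpha * (target - q[s][a])
            s, a = s_next, a_next
    return q


q = run_sarsa()
for state, row in enumerate(q):
    print(state, row)
```

The row of the terminal state stays at zero because no update is ever made from it, which mirrors the `q_target = r` branch in `learn`.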
sarsa-lambda
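The process above shows clearly that SARSA is an on-policy learning method: every step generates a new sample from the current policy, and the update is single-step.
What SARSA-lambda changes relative to SARSA: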
$$Q(s,a) \leftarrow Q(s,a) + \underbrace{\alpha[r + \gamma Q(s', a') - Q(s, a)]}_{\text{gets an extra coefficient}}$$
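The value of this coefficient depends on the distance from the current position to the end of the episode, on $\lambda$, and on the decay factor $\gamma$. The logic is as follows: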
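Reading the update off the `learn` method of the `SarsaLambdaTable` class below (the accumulating-trace variant), one step can be written as follows. The symbols $E$ for the eligibility trace table and $\delta$ for the TD error are notation introduced here, not taken from the original text.

$$
\begin{aligned}
\delta &= r + \gamma Q(s', a') - Q(s, a) \\
E(s, a) &\leftarrow E(s, a) + 1 \\
Q(\tilde{s}, \tilde{a}) &\leftarrow Q(\tilde{s}, \tilde{a}) + \alpha\,\delta\,E(\tilde{s}, \tilde{a}) \quad \text{for all } (\tilde{s}, \tilde{a}) \\
E(\tilde{s}, \tilde{a}) &\leftarrow \gamma\lambda\,E(\tilde{s}, \tilde{a}) \quad \text{for all } (\tilde{s}, \tilde{a})
\end{aligned}
$$

Under these rules, a pair that was visited $k$ steps ago and not since enters the current update with coefficient $(\gamma\lambda)^k$.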
The code for the SARSA-lambda class:

    import pandas as pd

    class SarsaLambdaTable(RL):  # inherits from the RL class
        def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
            super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
            # the lambda described above
            self.lambda_ = trace_decay
            self.eligibility_trace = self.q_table.copy()  # empty eligibility trace table

        def check_state_exist(self, state):
            if state not in self.q_table.index:
                # build a zero row for the new state
                new_row = pd.DataFrame(
                    [[0.0] * len(self.actions)],
                    index=[state],
                    columns=self.q_table.columns,
                )
                # DataFrame.append was removed in pandas 2.0, so use pd.concat
                self.q_table = pd.concat([self.q_table, new_row])
                # also add the state to the eligibility trace
                self.eligibility_trace = pd.concat([self.eligibility_trace, new_row])

        def learn(self, s, a, r, s_, a_):
            # this part is the same as in Sarsa
            self.check_state_exist(s_)
            q_predict = self.q_table.loc[s, a]  # .ix was removed in pandas 1.0, so use .loc
            if s_ != 'terminal':
                q_target = r + self.gamma * self.q_table.loc[s_, a_]
            else:
                q_target = r
            error = q_target - q_predict

            # from here on it differs:
            # add 1 for the visited state-action pair, marking it as an
            # indispensable link on the way to the reward
            self.eligibility_trace.loc[s, a] += 1

            # update the whole Q table
            self.q_table += self.lr * error * self.eligibility_trace

            # decay the eligibility trace over time: the further a step is
            # from the reward, the less "indispensable" it is
            self.eligibility_trace *= self.gamma * self.lambda_
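As a small numeric illustration of the decay step, the sketch below visits four hypothetical state-action pairs once each and records the weight that each one carries in the final update. It uses plain Python lists instead of the pandas table, and the values of `gamma` and `lambda_` are the defaults from the class above.

```python
gamma, lambda_ = 0.9, 0.9
trace = [0.0] * 4  # one hypothetical state-action pair per step

# Visit pairs 0, 1, 2, 3 in order, one per step
for step in range(4):
    trace[step] += 1.0                            # mark the pair just visited
    weights = list(trace)                         # weights applied to the TD error at this step
    trace = [gamma * lambda_ * e for e in trace]  # decay every trace afterwards

print(weights)  # approximately [0.531441, 0.6561, 0.81, 1.0]
```

The most recent pair gets the full TD error, while the pair visited three steps earlier gets only about $(0.9 \cdot 0.9)^3 \approx 0.53$ of it.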
There are other ways to update eligibility_trace, for example:

    self.eligibility_trace.loc[s, :] *= 0
    self.eligibility_trace.loc[s, a] = 1

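The two update styles behave differently when the same pair is visited repeatedly. The sketch below compares them on plain Python lists. The `run_trace` helper and the visit sequence are hypothetical, and the labels "accumulating" and "replacing" are the usual names for these variants rather than terms used in the original text.

```python
def run_trace(visits, n_states=2, n_actions=2, gamma=0.9, lambda_=0.9, replacing=False):
    """Return the trace table right after the last visit, before its decay."""
    e = [[0.0] * n_actions for _ in range(n_states)]
    for i, (s, a) in enumerate(visits):
        if i > 0:
            # decay all traces between consecutive steps
            e = [[gamma * lambda_ * x for x in row] for row in e]
        if replacing:
            e[s] = [0.0] * n_actions  # clear the other actions of this state
            e[s][a] = 1.0             # cap the visited pair at 1
        else:
            e[s][a] += 1.0            # accumulate without a cap
    return e


visits = [(0, 1), (0, 1), (0, 1)]  # the same pair visited three times in a row
acc = run_trace(visits)
rep = run_trace(visits, replacing=True)
print(acc[0][1])  # about 2.4661 = 1 + 0.81 * (1 + 0.81)
print(rep[0][1])  # 1.0
```

With the accumulating form, a pair that is revisited often can carry a weight well above 1, whereas the replacing form keeps every trace in the range 0 to 1.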
Note: eligibility_trace must be reset to zero at the start of every episode.



