栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

sarsa

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

sarsa

文章目录
  • sarsa: state-action-reward-state-action
  • sarsa-lambda

sarsa: state-action-reward-state-action

和q-learning相似,但是在 s 2 s_2 s2​时,不是取 max ⁡ Q max Q maxQ,而是直接根据policy选择一个action。这也是为什么是on-policy算法,而q-learning是off-policy

sarsa算法思路:

q-learning相比是一种更加贪婪的算法,因为算法更新的时候使用的是 max ⁡ Q max Q maxQ

具体的代码

for eposide in range(eposides):
	observation = env.reset()
	action = RL.choose_action(observation)
	while True:
		env.render()
		observation_next, reward, done = env.step(action)
		action_next = RL.choose_action(observation_next)
		RL.learn(observation, action, reward, observation_next, action_next)
		observation = observation_next
		action = action_next
		if done:
			break
print('Game over')
env.destroy()

同样agent的框架大概为:

class SarsaTable(RL):

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)    # 表示继承关系

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # q_target 基于选好的 a_ 而不是 Q(s_) 的最大值
        else:
            q_target = r
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # 更新 q_table
sarsa-lambda

上面的过程可以很清晰地显示出:sarsa是一个on-policy的学习方法,每次都会根据current policy生成新的sample,是单步更新的

sarsa-lambda的改变,相对sarsa:
Q ( s , a ) ← Q ( s , a ) + α [ r + γ Q ( s ′ , a ′ ) − Q ( s , a ) ] ⏟ 增加一个系数 Q(s,a) leftarrow Q(s,a) + underbrace{alpha[r + gamma Q(s', a') - Q(s, a)]}_{text{增加一个系数}} Q(s,a)←Q(s,a)+增加一个系数 α[r+γQ(s′,a′)−Q(s,a)]​​
这个系数具体的值与当前位置到回合结束的距离、 λ lambda λ、衰减系数 γ gamma γ相关,其逻辑具体为:

具体的代码

sarsa-lambda类

class SarsaLambdaTable(RL): # 继承 RL class
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
        super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

        # 上述lambda
        self.lambda_ = trace_decay
        self.eligibility_trace = self.q_table.copy()    # 空的 eligibility trace 表
    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append new state to q table
            to_be_append = pd.Series(
                    [0] * len(self.actions),
                    index=self.q_table.columns,
                    name=state,
                )
            self.q_table = self.q_table.append(to_be_append)

            # also update eligibility trace
            self.eligibility_trace = self.eligibility_trace.append(to_be_append)
    def learn(self, s, a, r, s_, a_):
        # 这部分和 Sarsa 一样
        self.check_state_exist(s_)
        q_predict = self.q_table.ix[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.ix[s_, a_]
        else:
            q_target = r
        error = q_target - q_predict

        # 这里开始不同:
        # 对于经历过的 state-action, 我们让他+1, 证明他是得到 reward 路途中不可或缺的一环
        self.eligibility_trace.ix[s, a] += 1

        # Q table 更新
        self.q_table += self.lr * error * self.eligibility_trace

        # 随着时间衰减 eligibility trace 的值, 离获取 reward 越远的步, 他的"不可或缺性"越小
        self.eligibility_trace *= self.gamma*self.lambda_

更新eligibility_trace还有其他方式,如:

self.eligibility_trace.ix[s, :] *= 0
self.eligibility_trace.ix[s, a] = 1

注意:在每回合开始的时候,eligibility_trace需要清零

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/269837.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号