強化学習の初心者向けに、Q学習でCart-Poleを動かしてみます。

【強化学習入門】Q学習でCart-Poleを動かす(Google Colab+OpenAI Gym)

強化学習の入門として「Cart-Pole問題」というものがあります(強化学習の”Hello World!”と呼ばれているらしい)。Cart-Pole問題とは、棒(Pole)の支点を台車(Cart)に固定して、その台車を「左右」に動かすことでバランスをとり、棒が倒れないようにする問題のことです。

Cart-Poleのモデルを自身で実装してもよいのですが、少し面倒なので、この記事ではOpenAI Gymという有名なライブラリを利用します。

cartpole

そして、そのOpenAI GymのCart-PoleをQ学習で動かしてみます。
(Q学習と深層学習を組み合わせたDQNで動かしてみたい方々はコチラ

実行環境はGoogle Colaboratoryを想定しています。ブラウザでGoogle Colaboratoryをひらき、下図のように、本記事のプログラムをコピー&ペーストすれば実行できます。

cartpole_gcolab

この記事のポイント

  • 強化学習の入門者向け
  • 強化学習を実装レベルで勉強(i.e. アルゴリズムの詳細解説は省略)
  • Google Colaboratoryで動かせるコードをご紹介(i.e. ブラウザ上だけで完結可能)

パッケージのインストール

Google Colaboratoryで以下のコマンドをコピー&ペーストして、OpenAI Gymをインストールします。

!pip install gym

表示のためのパッケージもインストールします。

!apt update
!apt install xvfb
!pip install pyvirtualdisplay

まずはランダムに動かしてみる

OpenAI GymのCart-Poleの挙動を見るために、まずはランダムに動かしてみます。

Google Colaboratoryで動かすために、描画関連の実装が少し複雑になっています。その部分は理解する必要がないので、軽く読み飛ばすとよいでしょう。

ソースコード

## OpenAI Gym
import gym

## 表示用のパッケージ
import base64
import io
from gym.wrappers import Monitor
from IPython import display
from pyvirtualdisplay import Display


## 仮想のディスプレイを用意
virtual_display = Display()
virtual_display.start()

## Cart-Poleの環境を用意
## Google Colabで描画するためにgym.wrappers.Monitorを利用
env = Monitor(gym.make('CartPole-v0'),'./videos/', force=True)
print("env.observation_space.shape = ", env.observation_space.shape)
print("env.action_space.n = ", env.action_space.n)

obs = env.reset()

## ランダムに100ステップ動かす
for t in range(100):
	obs, reward, is_done, info = env.step(env.action_space.sample())
	print("obs = ", obs)
	print("reward = ", reward)
	print("info = ", info)

	## 終了判定(e.g. ポールが倒れた場合など)
	if is_done:
		print("Episode finished after {} timesteps".format(t+1))
		env.reset()
		break

## 以下は描画用
for frame in env.videos:
	print("frame = ", frame)
	video = io.open(frame[0], 'r+b').read()
	encoded = base64.b64encode(video)

	## 動画を埋め込み
	display.display(display.HTML(data="""
		<video controls>
		<source src="data:video/mp4;base64,{0}" type="video/mp4" />
		</video>
		""".format(encoded.decode('ascii'))))

結果

ランダムに動かしているだけなので、すぐに倒れてしまいます。

cartpole_random

Q学習で動かしてみる

ソースコード

## OpenAI Gym
import gym

## 表示用のパッケージ
import base64
import io
from gym.wrappers import Monitor
from IPython import display
from pyvirtualdisplay import Display

## その他必要なパッケージ
import numpy as np


## Brainクラス
class Brain:
	def __init__(self, num_states, list_state_range, list_state_reso, num_actions, gamma, r, lr):
		## パラメータをセット
		self.num_states = num_states
		self.list_state_range = list_state_range
		self.list_state_reso = list_state_reso
		self.num_actions = num_actions

		self.eps = 1.0  # for epsilon greedy algorithm
		self.gamma = gamma
		self.r = r
		self.lr = lr

		## Qテーブルを用意
		self.q_table = np.random.rand(np.prod(list_state_reso), num_actions)

	## ビンの配列(等差数列)を生成
	## 引数:最初の値、最後の値、要素数
	def bins(self, clip_min, clip_max, num):
		return np.linspace(clip_min, clip_max, num + 1)[1:-1]

	## 観測情報をQテーブル上のインデックスへ変換
	## 引数:観測情報
	def getStateIndex(self, obs):
		list_index = []
		for i in range(self.num_states):
			index = np.digitize(obs[i], bins=self.bins(self.list_state_range[i][0], self.list_state_range[i][1], self.list_state_reso[i]))	# obs[i]が所属するビンのインデックスを取得
			list_index.append(index)
		return sum([index*int(np.prod(self.list_state_reso[:i])) for i, index in enumerate(list_index)])	# 4次元のインデックスを1次元に変換して返す

	## Qテーブルを更新
	## 引数:観測情報、アクションのインデックス、報酬、アクション後の観測情報
	def updateQtable(self, obs, action, reward, next_obs):
		q = self.q_table[self.getStateIndex(obs), action]
		next_q_max = np.max(self.q_table[self.getStateIndex(next_obs)])
		self.q_table[self.getStateIndex(obs), action] = q + self.lr*(reward + self.gamma*next_q_max - q)

	## アクションを決定
	## 引数:観測情報、訓練フラグ
	def getAction(self, obs, is_training):
		if is_training and np.random.rand() < self.eps:
			action = np.random.randint(self.num_actions)
		else:
			action = np.argmax(self.q_table[self.getStateIndex(obs)])
		## epsを更新
		if is_training and self.eps > 0.1:
			self.eps *= self.r
		return action


## Agentクラス
class Agent:
	def __init__(self, num_states, list_state_range, list_state_reso, num_actions, gamma, r, lr):
		## Brainを用意
		self.brain = Brain(num_states, list_state_range, list_state_reso, num_actions, gamma, r, lr)

	## Qテーブルを更新
	## 引数:観測情報、アクションのインデックス、報酬、アクション後の観測情報 
	def updateQtable(self, obs, action, reward, next_obs):
		self.brain.updateQtable(obs, action, reward, next_obs)

	## アクションを決定
	## 引数:観測情報、訓練フラグ 
	def getAction(self, obs, is_training):
		action = self.brain.getAction(obs, is_training)
		return action


## Environmentクラス
class Environment:
	def __init__(self, num_episodes, max_step, gamma, r, lr):
		## パラメータをセット
		self.num_episodes = num_episodes
		self.max_step = max_step
		## Cart-Poleの環境を用意
		self.env = Monitor(gym.make('CartPole-v0'), './videos/', video_callable=(lambda ep: ep % 100 == 0), force=True)	# 100エピソードごとの動画を保存
		## Agentを用意
		num_states = self.env.observation_space.shape[0]	# position, velocity, angle, angular velocity
		list_state_range = []
		for i in range(num_states):
			list_state_range.append([self.env.observation_space.low[i], self.env.observation_space.high[i]])
		list_state_range[1] = [-3.0, 3.0]	# 適当に範囲を設定
		list_state_range[3] = [-0.5, 0.5]	# 適当に範囲を設定
		print("list_state_range = ", list_state_range)
		list_state_reso = [4, 4, 6, 6]	# 各状態量のQテーブルにおける解像度を適当に設定
		num_actions = self.env.action_space.n	# アクションは「右」、「左」の2つ

		self.agent = Agent(num_states, list_state_range, list_state_reso, num_actions, gamma, r, lr)
 
	## 訓練
	## 引数:なし
	def train(self):
		num_completed_episodes = 0
  
		## 指定するエピソード数でループ
		for episode in range(self.num_episodes):
			obs = self.env.reset()
			episode_reward = 0

			## 指定する最大ステップ数でループ
			for step in range(self.max_step):
				## アクションを決定
				action = self.agent.getAction(obs, is_training=True)
				## アクション後の状態を取得
				next_obs, _, is_done, _ = self.env.step(action)
				## 報酬を付与
				if is_done:
					if step < max_step - 1:
						reward = -100
					else:
						reward = 1
						num_completed_episodes += 1
				else:
					reward = 1
				episode_reward += reward
				## Qテーブルを更新
				self.agent.updateQtable(obs, action, reward, next_obs)
				## 次のステップへ
				obs = next_obs

				## 終了判定
				if is_done:
					print('{0} Episode: Finished after {1} time steps with reward {2}'.format(episode, step+1, episode_reward))
					break
		print("num_completed_episodes = ", num_completed_episodes)

	## 評価(Q値が最大となるアクションを選択して、Qテーブルの更新はしない)
	## 引数:なし
	def evaluate(self):
		obs = self.env.reset()
		
		for step in range(self.max_step):
			## Q値が最大となるアクションを選択
			action = self.agent.getAction(obs, is_training=False)
			## アクション後の状態を取得
			next_obs, _, is_done, _ = self.env.step(action)
			## 次のステップへ
			obs = next_obs

			## 終了判定
			if is_done:
				print('Evaluation: Finished after {} time steps'.format(step+1))
				break


## 描画用の関数
def show_video(env):
	env.reset()
	for frame in env.videos:
		print("frame = ", frame)
		video = io.open(frame[0], 'r+b').read()
		encoded = base64.b64encode(video)

		display.display(display.HTML(data="""
			<video controls>
			<source src="data:video/mp4;base64,{0}" type="video/mp4" />
			</video>
			""".format(encoded.decode('ascii')))
		)


## 仮想のディスプレイを用意
virtual_display = Display()
virtual_display.start()

## パラメータ
num_episodes = 500
max_step = 200
gamma = 0.9
r = 0.99
lr = 0.5

## 実行
cartpole_env = Environment(num_episodes, max_step, gamma, r, lr)
cartpole_env.train()
cartpole_env.evaluate()
show_video(cartpole_env.env)

結果

学習終了後は、上手くバランスをとって倒れなくなりました。

cartpole_qlearning

さいごに

Google Colaboratory上で、OpenAI GymのCart-PoleをQ学習で動かしてみました。

少しでも参考になれば幸いです。


以上です。

関連記事

Q学習の実装を理解した次は、Q学習と深層学習を組み合わせたDQN(Deep Q Network)を実装してみませんか?

【強化学習入門】DQNでCart-Poleを動かす(Google Colab+OpenAI Gym)

Ad.