
Commit 47c3325

[Week 08] Update & clean up the notebook
1 parent 9137d39 commit 47c3325

File tree

1 file changed: +54 -38 lines

week08_pomdp/practice_pytorch.ipynb (+54 -38)

@@ -6,19 +6,25 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import sys\n",
-"if 'google.colab' in sys.modules:\n",
-" !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/0ccb0673965dd650d9b284e1ec90c2bfd82c8a94/week08_pomdp/atari_util.py\n",
-" !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/0ccb0673965dd650d9b284e1ec90c2bfd82c8a94/week08_pomdp/env_pool.py\n",
-"\n",
-" !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n",
-" !touch .setup_complete\n",
-"# If you are running on a server, launch xvfb to record game videos\n",
-"# Please make sure you have xvfb installed\n",
-"import os\n",
-"if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n",
-" !bash ../xvfb start\n",
-" os.environ['DISPLAY'] = ':1'"
+"import sys, os\n",
+"if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n",
+" # Install xvfb and our launcher script for it\n",
+" !apt-get install -y xvfb\n",
+" !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/xvfb -O ../xvfb\n",
+"\n",
+" !pip install gym[atari,accept-rom-license]\n",
+"\n",
+" !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/atari_util.py\n",
+" !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/env_pool.py\n",
+"\n",
+" !touch .setup_complete\n",
+"\n",
+"# This code creates a virtual display to draw game images on.\n",
+"# It will have no effect if your machine has a monitor.\n",
+"import os\n",
+"if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n",
+" !bash ../xvfb start\n",
+" os.environ['DISPLAY'] = ':1'"
 ]
 },
 {
@@ -53,7 +59,6 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n",
 "Observation shape: (1, 42, 42)\n",
 "Num actions: 14\n",
 "Action names: ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'DOWNRIGHT', 'DOWNLEFT', 'RIGHTFIRE', 'LEFTFIRE', 'DOWNFIRE', 'UPRIGHTFIRE', 'UPLEFTFIRE', 'DOWNRIGHTFIRE', 'DOWNLEFTFIRE']\n"
@@ -70,6 +75,7 @@
 " env = PreprocessAtari(env, height=42, width=42,\n",
 " crop=lambda img: img[60:-30, 15:],\n",
 " color=False, n_frames=1)\n",
+" env.metadata['render_fps'] = 30\n",
 " return env\n",
 "\n",
 "\n",
@@ -143,7 +149,7 @@
 "\n",
 "Let's design another agent that has a recurrent neural net memory to solve this. Here's a sketch.\n",
 "\n",
-"![img](img1.jpg)\n"
+"![img](https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/img1.jpg)\n"
 ]
 },
 {
@@ -204,13 +210,15 @@
 " return new_state, (logits, state_value)\n",
 "\n",
 " def get_initial_state(self, batch_size):\n",
-" \"\"\"Return a list of agent memory states at game start. Each state is a np array of shape [batch_size, ...]\"\"\"\n",
-" return torch.zeros((batch_size, 128)), torch.zeros((batch_size, 128))\n",
+" \"\"\"Return the agent memory state at the beginning of the game. Each state is a np array of shape [batch_size, ...]\"\"\"\n",
+" h0 = torch.zeros((batch_size, 128))\n",
+" c0 = torch.zeros((batch_size, 128))\n",
+" return h0, c0\n",
 "\n",
 " def sample_actions(self, agent_outputs):\n",
 " \"\"\"pick actions given numeric agent outputs (np arrays)\"\"\"\n",
 " logits, state_values = agent_outputs\n",
-" probs = F.softmax(logits)\n",
+" probs = F.softmax(logits, dim=-1)\n",
 " return torch.multinomial(probs, 1)[:, 0].data.numpy()\n",
 "\n",
 " def step(self, prev_state, obs_t):\n",
@@ -258,11 +266,13 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"import tqdm\n",
+"\n",
 "def evaluate(agent, env, n_games=1):\n",
 " \"\"\"Plays an entire game start to end, returns session rewards.\"\"\"\n",
 "\n",
 " game_rewards = []\n",
-" for _ in range(n_games):\n",
+" for _ in tqdm.notebook.trange(n_games):\n",
 " # initial observation and memory\n",
 " observation = env.reset()\n",
 " prev_memories = agent.get_initial_state(1)\n",
@@ -292,7 +302,7 @@
 "source": [
 "import gym.wrappers\n",
 "\n",
-"with gym.wrappers.Monitor(make_env(), directory=\"videos\", force=True) as env_monitor:\n",
+"with gym.wrappers.RecordVideo(make_env(), video_folder=\"videos\") as env_monitor:\n",
 " rewards = evaluate(agent, env_monitor, n_games=3)\n",
 "\n",
 "print(rewards)"
@@ -336,7 +346,7 @@
 "### Training on parallel games\n",
 "\n",
 "We introduce a class called EnvPool - it's a tool that handles multiple environments for you. Here's how it works:\n",
-"![img](img2.jpg)"
+"![img](https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/img2.jpg)"
 ]
 },
 {
@@ -354,7 +364,7 @@
 "metadata": {},
 "source": [
 "We gonna train our agent on a thing called __rollouts:__\n",
-"![img](img3.jpg)\n",
+"![img](https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/img3.jpg)\n",
 "\n",
 "A rollout is just a sequence of T observations, actions and rewards that agent took consequently.\n",
 "* First __s0__ is not necessarily initial state for the environment\n",
@@ -446,7 +456,7 @@
 "source": [
 "def to_one_hot(y, n_dims=None):\n",
 " \"\"\" Take an integer tensor and convert it to 1-hot matrix. \"\"\"\n",
-" y_tensor = y.to(dtype=torch.int64).view(-1, 1)\n",
+" y_tensor = y.to(dtype=torch.int64).reshape(-1, 1)\n",
 " n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1\n",
 " y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)\n",
 " return y_one_hot"
@@ -472,7 +482,7 @@
 " states = torch.tensor(np.asarray(states), dtype=torch.float32)\n",
 " actions = torch.tensor(np.array(actions), dtype=torch.int64) # shape: [batch_size, time]\n",
 " rewards = torch.tensor(np.array(rewards), dtype=torch.float32) # shape: [batch_size, time]\n",
-" is_not_done = torch.tensor(np.array(is_not_done), dtype=torch.float32) # shape: [batch_size, time]\n",
+" is_not_done = torch.tensor(np.array(is_not_done), dtype=torch.bool) # shape: [batch_size, time]\n",
 " rollout_length = rewards.shape[1] - 1\n",
 "\n",
 " # predict logits, probas and log-probas using an agent.\n",
@@ -483,7 +493,7 @@
 " for t in range(rewards.shape[1]):\n",
 " obs_t = states[:, t]\n",
 "\n",
-" # use agent to comute logits_t and state values_t.\n",
+" # use agent to compute logits_t and state values_t.\n",
 " # append them to logits and state_values array\n",
 "\n",
 " memory, (logits_t, values_t) = <YOUR CODE>\n",
@@ -521,9 +531,10 @@
 " V_next = state_values[:, t + 1].detach() # next state values\n",
 " # log-probability of a_t in s_t\n",
 " logpi_a_s_t = logprobas_for_actions[:, t]\n",
+" is_not_done_t = is_not_done[:, t]\n",
 "\n",
 " # update G_t = r_t + gamma * G_{t+1} as we did in week6 reinforce\n",
-" cumulative_returns = G_t = r_t + gamma * cumulative_returns\n",
+" cumulative_returns = G_t = r_t + torch.where(is_not_done_t, gamma * cumulative_returns, 0)\n",
 "\n",
 " # Compute temporal difference error (MSE for V(s))\n",
 " value_loss += <YOUR CODE>\n",
@@ -579,7 +590,6 @@
 "outputs": [],
 "source": [
 "from IPython.display import clear_output\n",
-"from tqdm import trange\n",
 "from pandas import DataFrame\n",
 "moving_average = lambda x, **kw: DataFrame(\n",
 " {'x': np.asarray(x)}).x.ewm(**kw).mean().values\n",
@@ -593,21 +603,27 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"for i in trange(15000):\n",
+"log_every = 100\n",
+"\n",
+"for i in tqdm.trange(15000):\n",
+" # tqdm.notebook.tqdm is not trivial to use here because clear_output(True)\n",
+" # also removes the tqdm widget\n",
 "\n",
 " memory = list(pool.prev_memory_states)\n",
-" rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(\n",
-" 10)\n",
-" train_on_rollout(rollout_obs, rollout_actions,\n",
-" rollout_rewards, rollout_mask, memory)\n",
+" rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(10)\n",
+" train_on_rollout(rollout_obs, rollout_actions, rollout_rewards, rollout_mask, memory)\n",
 "\n",
-" if i % 100 == 0:\n",
+" if i % log_every == 0:\n",
 " rewards_history.append(np.mean(evaluate(agent, env, n_games=1)))\n",
 " clear_output(True)\n",
-" plt.plot(rewards_history, label='rewards')\n",
-" plt.plot(moving_average(np.array(rewards_history),\n",
-" span=10), label='rewards ewma@10')\n",
+" plt.plot(\n",
+" np.arange(len(rewards_history)) * log_every,\n",
+" rewards_history, label='rewards')\n",
+" plt.plot(\n",
+" np.arange(len(rewards_history)) * log_every,\n",
+" moving_average(np.array(rewards_history), span=10), label='rewards ewma@10')\n",
 " plt.legend()\n",
+" plt.grid()\n",
 " plt.show()\n",
 " if rewards_history[-1] >= 10000:\n",
 " print(\"Your agent has just passed the minimum homework threshold\")\n",
@@ -628,7 +644,7 @@
 "Since we use a policy-based method, we also keep track of __policy entropy__ - the same one you used as a regularizer. The only important thing about it is that your entropy shouldn't drop too low (`< 0.1`) before your agent gets the yellow belt. Or at least it can drop there, but _it shouldn't stay there for long_.\n",
 "\n",
 "If it does, the culprit is likely:\n",
-"* Some bug in entropy computation. Remember that it is $ - \sum p(a_i) \cdot log p(a_i) $\n",
+"* Some bug in entropy computation. Remember that it is $ - \sum p(a_i) \cdot \log p(a_i) $\n",
 "* Your agent architecture converges too fast. Increase entropy coefficient in actor loss. \n",
 "* Gradient explosion - just [clip gradients](https://stackoverflow.com/a/56069467) and maybe use a smaller network\n",
 "* Us. Or PyTorch developers. Or aliens. Or lizardfolk. Contact us on forums before it's too late!\n",
@@ -651,7 +667,7 @@
 "source": [
 "import gym.wrappers\n",
 "\n",
-"with gym.wrappers.Monitor(make_env(), directory=\"videos\", force=True) as env_monitor:\n",
+"with gym.wrappers.RecordVideo(make_env(), video_folder=\"videos\") as env_monitor:\n",
 " final_rewards = evaluate(agent, env_monitor, n_games=20)\n",
 "\n",
 "print(\"Final mean reward\", np.mean(final_rewards))"
