本節包含本書中使用的實用函數和類的實現。
import collections
import inspect
from IPython import display
from torch import nn
from d2l import torch as d2l
import collections
import inspect
import tensorflow as tf
from IPython import display
from d2l import tensorflow as d2l
超參數。
@d2l.add_to_class(d2l.HyperParameters)  #@save
def save_hyperparameters(self, ignore=None):
    """Save the calling function's arguments into class attributes.

    Inspects the local variables of the caller's frame (e.g. an
    ``__init__``). Every name that is not ``self`` and not listed in
    ``ignore``, and that does not start with an underscore, is stored both
    in the ``self.hparams`` dict and as an attribute of ``self``.

    Args:
        ignore: Optional iterable of variable names to skip. ``None`` (the
            default) skips nothing beyond ``self`` and ``_``-prefixed names.
    """
    # Build the exclusion set once per call (the previous version used a
    # shared mutable ``[]`` default and rebuilt the set for every item).
    skip = set(ignore or ()) | {'self'}
    frame = inspect.currentframe().f_back
    try:
        _, _, _, local_vars = inspect.getargvalues(frame)
    finally:
        # Drop the frame reference promptly to avoid a reference cycle,
        # as recommended by the inspect module documentation.
        del frame
    self.hparams = {k: v for k, v in local_vars.items()
                    if k not in skip and not k.startswith('_')}
    for k, v in self.hparams.items():
        setattr(self, k, v)
進度條。
@d2l.add_to_class(d2l.ProgressBoard)  #@save
def draw(self, x, y, label, every_n=1):
    """Record the point ``(x, y)`` on curve ``label`` and redraw the board.

    Raw points are buffered per label. Once ``every_n`` of them have been
    collected, their mean is appended to the plotted curve and the buffer is
    cleared. When ``self.display`` is falsy the averaged data is still
    recorded, but nothing is drawn.

    Args:
        x: x-coordinate of the new point.
        y: y-coordinate of the new point.
        label: Name of the curve the point belongs to (shown in the legend).
        every_n: Number of raw points averaged into one plotted point.
    """
    Point = collections.namedtuple('Point', ['x', 'y'])
    if not hasattr(self, 'raw_points'):
        # Lazily create the per-label buffers: raw (pending) points and the
        # averaged points that are actually plotted.
        self.raw_points = collections.OrderedDict()
        self.data = collections.OrderedDict()
    if label not in self.raw_points:
        self.raw_points[label] = []
        self.data[label] = []
    points = self.raw_points[label]
    line = self.data[label]
    points.append(Point(x, y))
    if len(points) != every_n:
        return

    def mean(values):
        return sum(values) / len(values)

    line.append(Point(mean([p.x for p in points]),
                      mean([p.y for p in points])))
    points.clear()
    if not self.display:
        return
    d2l.use_svg_display()
    if self.fig is None:
        self.fig = d2l.plt.figure(figsize=self.figsize)
    # Resolve the target axes first so the curves are drawn on the same axes
    # that receive the limits, labels and legend below.
    axes = self.axes if self.axes else d2l.plt.gca()
    plt_lines, labels = [], []
    # NOTE: zip stops at the shortest input, so curves beyond len(self.ls)
    # or len(self.colors) are not drawn.
    for (k, v), ls, color in zip(self.data.items(), self.ls, self.colors):
        plt_lines.append(axes.plot([p.x for p in v], [p.y for p in v],
                                   linestyle=ls, color=color)[0])
        labels.append(k)
    if self.xlim:
        axes.set_xlim(self.xlim)
    if self.ylim:
        axes.set_ylim(self.ylim)
    # Only label the x-axis when an xlabel is configured; falling back to
    # ``self.x`` raised AttributeError because no such attribute is set.
    if self.xlabel:
        axes.set_xlabel(self.xlabel)
    axes.set_ylabel(self.ylabel)
    axes.set_xscale(self.xscale)
    axes.set_yscale(self.yscale)
    axes.legend(plt_lines, labels)
    display.display(self.fig)
    display.clear_output(wait=True)
添加 FrozenLake 環境
def frozen_lake(seed, is_slippery=False):  #@save
    """Build a seeded FrozenLake-v1 environment and extract its MDP.

    See https://www.gymlibrary.dev/environments/toy_text/frozen_lake/ to learn
    more about this env. How ``env.P`` is processed is adapted from
    https://sites.google.com/view/deep-rl-bootcamp/labs

    Args:
        seed: Seed applied to the environment and to its action space.
        is_slippery: Forwarded to ``gym.make``. Defaults to ``False`` (the
            deterministic variant), matching the previous hard-coded value.

    Returns:
        A dict with:
          * ``'desc'``: 2D array specifying what each grid item means.
          * ``'num_states'`` / ``'num_actions'``: sizes of the discrete
            observation (state) and action spaces.
          * ``'trans_prob_idx'``, ``'nextstate_idx'``, ``'reward_idx'``,
            ``'done_idx'``: positions inside each
            ``(transition probability, next state, reward, done)`` tuple.
          * ``'mdp'``: maps ``(s, a)`` to the list of such tuples.
          * ``'env'``: the environment object itself.
    """
    env = gym.make('FrozenLake-v1', is_slippery=is_slippery)
    if hasattr(env, 'seed'):
        env.seed(seed)
    else:
        # Env.seed() was removed from newer gym releases; seed via reset().
        env.reset(seed=seed)
    # Space.seed() installs a freshly seeded RNG, so also calling
    # ``np_random.seed`` beforehand was redundant (and fails when
    # ``np_random`` is a numpy Generator, which has no ``seed`` method).
    env.action_space.seed(seed)

    # Read model attributes from the unwrapped env so we do not depend on
    # wrappers forwarding unknown attributes.
    base_env = env.unwrapped
    env_info = {
        'desc': base_env.desc,
        # The space sizes equal the legacy nS / nA attributes, which newer
        # gym versions no longer provide.
        'num_states': env.observation_space.n,
        'num_actions': env.action_space.n,
        'trans_prob_idx': 0,  # Index of transition probability entry
        'nextstate_idx': 1,   # Index of next state entry
        'reward_idx': 2,      # Index of reward entry
        'done_idx': 3,        # Index of done entry
    }
    # base_env.P = {s: {a0: [(p(s'|s,a0), s', reward, done), ...], a1: [...]}}
    # e.g. a transition list: [(0.3, 0, 0, False), (0.3, 0, 0, False),
    #                          (0.3, 4, 1, False)]
    env_info['mdp'] = {(s, a): pxrds
                       for s, others in base_env.P.items()
                       for a, pxrds in others.items()}
    env_info['env'] = env
    return env_info
創造環境
顯示值函數
def show_value_function_progress(env_desc, V, pi): #@save
# Plot, for each iteration k, the value function V[k] as a 4x4 heat map with
# the grid letters from env_desc and a red arrow for the action pi[k] chooses.
# V: [num_iters, num_states] (each row is reshaped to 4x4 below)
# pi: [num_iters, num_states] (integer actions 0-3, reshaped to 4x4 below)
# env_desc: presumably the 2D byte-string grid returned as frozen_lake's
# 'desc' (elements are indexed [y, x] and .decode()d) -- TODO confirm.
# How to visualize value function is adapted (but changed) from: https://sites.google.com/view/deep-rl-bootcamp/labs
# NOTE(review): plt and np are not imported anywhere in the visible source.
# NOTE(review): num_iters and fig are assigned but not used in the visible part.
num_iters = V.shape[0]
fig, ax = plt.subplots(figsize=(15, 15))
for k in range(V.shape[0]):
# One panel per iteration on a 4x4 subplot grid, so at most 16 iterations fit.
plt.subplot(4, 4, k + 1)
plt.imshow(V[k].reshape(4,4), cmap="bone")
ax = plt.gca()
# Minor ticks at cell boundaries draw white grid lines between cells.
ax.set_xticks(np.arange(0, 5)-.5, minor=True)
ax.set_yticks(np.arange(0, 5)-.5, minor=True)
ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
ax.tick_params(which="minor", bottom=False, left=False)
ax.set_xticks([])
ax.set_yticks([])
# LEFT action: 0, DOWN action: 1
# RIGHT action: 2, UP action: 3
# Arrow offsets (dx, dy) per action.
# NOTE(review): UP (3) uses the same (-.25, 0) offset as LEFT (0), so UP is
# drawn as a left arrow. Since DOWN is (0, .25), UP was presumably meant to
# be (0, -.25) -- confirm.
action2dxdy = {0:(-.25, 0),1: (0, .25),
2:(0.25, 0),3: (-.25, 0)}
for y in range(4):
for x in range(4):
action = pi[k].reshape(4,4)[y, x]
dx, dy = action2dxdy[action]
# Label the cell with its grid letter: holes 'H' in yellow, goal 'G' in
# white, everything else in smaller green text.
if env_desc[y,x].decode() == 'H':
ax.text(x, y, str(env_desc[y,x].decode()),
ha="center", va="center", color="y",
size=20, fontweight='bold')
elif env_desc[y,x].decode() == 'G':
ax.text(x, y, str(env_desc[y,x].decode()),
ha="center", va="center", color="w",
size=20, fontweight='bold')
else:
ax.text(x, y, str(env_desc[y,x].decode()),
ha="center", va="center", color="g",
size=15, fontweight='bold')
# No arrow for cells with G and H labels
if env_desc[y,x].decode() != 'G' and env_desc[y,x].decode() != 'H':
ax.arrow(x, y, dx, dy, color='r', head_width=0.2, head_length=0.15)
# NOTE(review): the source is truncated here; the remaining arguments of
# this call and the rest of the function are missing.
ax.set_title("Step = " + str(k + 1),
評論
查看更多