基于TensorFlow的神经网络库
文件列表(压缩包大小 8.81M)
免费
概述
Sonnet是在TensorFlow 2之上构建的库,旨在为机器学习研究进行简单可组合的抽象。
Sonnet是由DeepMind的研究人员设计构建的,它可以构造用于不同目的的神经网络(无/监督学习,强化学习等),我们发现这对我们的组织能够成功的抽象,你的也可以这样!
更具体地说,Sonnet提供了一个简单而强大的编程模型,该模型围绕一个概念:snt.Module
。 模块可以保存对在用户输入上应用某些功能的参数,其他模块和方法的引用。 Sonnet附带了许多预定义的模块(例如snt.Linear
,snt.Conv2D
,snt.BatchNorm
)和一些预定义的模块网络(例如snt.nets.MLP
),同时也鼓励用户构建自己的模块。
与许多框架不同,Sonnet在如何使用模块方面毫无保留。 模块设计为自包含的并且彼此之间完全解耦。Sonnet没有附带训练框架,因此鼓励用户自己构建或采用他人构建的框架。
Sonnet的设计也易于理解,我们的代码清晰且重点突出。 在选择默认值(例如初始参数值的默认值)的地方,会尽量指出原因。
尝试Sonnet的最简单方法是使用Google Colab,它提供了一个免费的Python笔记本连接到GPU或TPU。
开始安装TensorFlow 2.0和Sonnet 2:
$ pip install tensorflow tensorflow-probability
$ pip install dm-sonnet
运行以下命令验证安装的东西:
import tensorflow as tf
import sonnet as snt
print("TensorFlow version {}".format(tf.__version__))
print("Sonnet version {}".format(snt.__version__))
使用现有模块
Sonnet随附了许多可以轻松使用的内置模块。 例如,要定义一个MLP,我们可以使用snt.Sequential
模块调用模块序列,将给定模块的输出作为下一个模块的输入。 我们可以使用snt.Linear
和tf.nn.relu
实际定义我们的计算:
mlp = snt.Sequential([
snt.Linear(1024),
tf.nn.relu,
snt.Linear(10),
])
要使用我们的模块,需要“调用”它。 顺序模块(和大多数模块)定义了__call__
方法,可以按名称调用它们:
logits = mlp(tf.random.normal([batch_size, input_size]))
为模块请求所有参数也是很常见的。 Sonnet中的大多数模块在第一次使用某些输入调用它们时都会创建它们的参数(因为在大多数情况下,参数的形状是输入的函数)。 Sonnet模块提供了两个用于访问参数的属性。
variables
属性返回给定模块引用的所有tf.Variables
:
all_variables = mlp.variables
值得注意的是tf.Variables
不只是用于模型的参数。 例如,它们用于在snt.BatchNorm
中使用的度量标准中保持状态。 在大多数情况下,用户检索模块变量以将其传递给优化器以进行更新。 在这种情况下,不可训练变量通常不应该在该列表中,因为它们是通过不同的机制更新的。 TensorFlow具有内置机制,可将变量标记为“可训练的”(模型参数)与“不可训练的”(其他变量)。 Sonnet提供了一种从模块收集所有可训练变量的机制,这可能是你想要传递给优化器的东西:
model_parameters = mlp.trainable_variables
构建自己的模块
Sonnet强烈建议用户将snt.Module
子类化以定义自己的模块。 让我们从创建一个名为MyLinear
的简单线性层开始:
class MyLinear(snt.Module):
def __init__(self, output_size, name=None):
super(MyLinear, self).__init__(name=name)
self.output_size = output_size
@snt.once
def _initialize(self, x):
initial_w = tf.random.normal([x.shape[1], self.output_size])
self.w = tf.Variable(initial_w, name="w")
self.b = tf.Variable(tf.zeros([self.output_size]), name="b")
def __call__(self, x):
self._initialize(x)
return tf.matmul(x, self.w) + self.b
使用这个模块很简单:
mod = MyLinear(32)
mod(tf.ones([batch_size, input_size]))
通过子类化snt.Module
,可以免费获得许多不错的属性。 例如,__repr__
默认实现显示了构造函数参数(对于调试和自省非常有用):
>>> print(repr(mod))
MyLinear(output_size=10)
还可以获得variables
和trainable_variables
属性:
>>> mod.variables
(<tf.Variable 'my_linear/b:0' shape=(10,) ...)>,
<tf.Variable 'my_linear/w:0' shape=(1, 10) ...)>)
你可能会在上面的变量上注意到my_linear
前缀。 这是因为每当调用方法时,Sonnet模块也会进入模块名称范围。 通过输入模块名称范围,我们为TensorBoard之类的工具提供了更为有用的图形(例如,发生在my_linear中的所有操作都将在一个名为my_linear的组中)。
此外,你的模块现在将支持TensorFlow检查点和保存的模型,这是稍后介绍的高级功能。
Sonnet支持多种序列化格式。 我们支持的最简单的格式是Python的pickle
,并且对所有内置模块进行了测试,以确保可以在同一Python进程中通过pickle保存/加载它们。 总的来说,我们不鼓励使用pickle,因为TensorFlow的许多部分都不能很好地支持它,并且根据我们的经验来说它可能不够稳定。
参考:https://www.tensorflow.org/alpha/guide/checkpoints
TensorFlow Checkpointing可用于在训练期间定期保存参数值。 如果程序崩溃或停止,它将保存训练进度。 Sonnet旨在与TensorFlow Checkpointing完美协作:
checkpoint_root = "/tmp/checkpoints"
checkpoint_name = "example"
save_prefix = os.path.join(checkpoint_root, checkpoint_name)
my_module = create_my_sonnet_module() # Can be anything extending snt.Module.
# A `Checkpoint` object manages checkpointing of the TensorFlow state associated
# with the objects passed to it's constructor. Note that Checkpoint supports
# restore on create, meaning that the variables of `my_module` do **not** need
# to be created before you restore from a checkpoint (their value will be
# restored when they are created).
checkpoint = tf.train.Checkpoint(module=my_module)
# Most training scripts will want to restore from a checkpoint if one exists. This
# would be the case if you interrupted your training (e.g. to use your GPU for
# something else, or in a cloud environment if your instance is preempted).
latest = tf.train.latest_checkpoint(checkpoint_root)
if latest is not None:
checkpoint.restore(latest)
for step_num in range(num_steps):
train(my_module)
# During training we will occasionally save the values of weights. Note that
# this is a blocking call and can be slow (typically we are writing to the
# slowest storage on the machine). If you have a more reliable setup it might be
# appropriate to save less frequently.
if step_num and not step_num % 1000:
checkpoint.save(save_prefix)
# Make sure to save your final values!!
checkpoint.save(save_prefix)
参考:https://www.tensorflow.org/alpha/guide/saved_model
TensorFlow保存的模型可用于保存与Python源分离的网络副本。通过保存描述计算的TensorFlow图和包含权重值的检查点来启用此功能。
要创建保存的模型,要做的第一件事是创建要保存的snt.Module
:
my_module = snt.nets.MLP([1024, 1024, 10])
my_module(tf.ones([1, input_size]))
接下来,我们需要创建另一个模块来描述我们要导出的模型的特定部分。 我们建议这样做(而不是就地修改原始模型),可以对实际输出的内容进行细粒度的控制。 这通常可以避免创建非常大的已保存模型,这样一来,只共享模型中要共享的部分即可(例如,你只想共享GAN的生成器,同时使鉴别器不公开)。
@tf.function(input_signature=[tf.TensorSpec([None, input_size])])
def inference(x):
return my_module(x)
to_save = snt.Module()
to_save.inference = inference
to_save.all_variables = list(my_module.variables)
tf.saved_model.save(to_save, "/tmp/example_saved_model")
现在,我们在/tmp/example_saved_model
文件夹中有一个保存的模型:
$ ls -lh /tmp/example_saved_model
total 24K
drwxrwsr-t 2 tomhennigan 154432098 4.0K Apr 28 00:14 assets
-rw-rw-r-- 1 tomhennigan 154432098 14K Apr 28 00:15 saved_model.pb
drwxrwsr-t 2 tomhennigan 154432098 4.0K Apr 28 00:15 variables
加载此模型非常简单,无需在构建已保存模型的任何Python代码的情况下,即可在其他计算机上完成该加载:
loaded = tf.saved_model.load("/tmp/example_saved_model")
# Use the inference method. Note this doesn't run the Python code from `to_save`
# but instead uses the TensorFlow Graph that is part of the saved model.
loaded.inference(tf.ones([1, input_size]))
# The all_variables property can be used to retrieve the restored variables.
assert len(loaded.all_variables) > 0
注意,加载的对象不是Sonnet模块,它是一个容器对象,具有我们在上一个块中添加的特定方法(例如inference
)和属性(例如all_variables
)。
示例:https://github.com/deepmind/sonnet/blob/v2/examples/distributed_cifar10.ipynb
Sonnet支持使用自定义TensorFlow分布策略进行分布式培训。
Sonnet与使用tf.keras
进行的分布式训练之间的主要区别在于,在分配策略下运行时,Sonnet模块和优化器的行为不会有所不同(例如,我们不会平均梯度或同步批处理规范统计信息)。 我们认为,用户应该完全控制他们的训练的这些方面,不应将其拷到库中。 这里需要权衡的是,你需要在训练脚本中实现这些功能(通常,只有两行代码才能在应用优化程序之前全部减小梯度),或者在具有明确分布意识的模块中进行交换(例如snt.distribute.CrossReplicaBatchNorm
)。
我们的分布式Cifar-10示例演示了如何使用Sonnet进行多GPU训练。
如果遇到文件不能下载或其他产品问题,请添加管理员微信:ligongku001,并备注:产品反馈
评论(0)