horovod tensorflow 分布式多gpu_horovod.tensorflow_青盏的博客-程序员秘密

技术标签: DL tools  

概念
rank is your index within the entire ring, local_rank is your index within your node. For example, you have 4 nodes and 4 GPUs each node, so you spawn 16 workers. Every worker will have a rank [0, 15], and every worker will have a local_rank [0, 3]. You use local_rank for GPU pinning because there’s typically one GPU available on the node per process. It wouldn’t make sense to use rank here because rank could be 10, but you only have 4 GPUs so there is no GPU 10.

# 在其他import前引入
try:
    import horovod.tensorflow as hvd
    hvd.init()
except Exception as e:
    hvd = None
    print('no horovod')

# 打印信息
if hvd:
    tf.logging.info('Total workers: {}, local workers: {}'.format(
        hvd.size(), hvd.local_size()))
    tf.logging.info('Global rank: {}, local rank: {}'.format(
        hvd.rank(), hvd.local_rank()))

# 数据集读取配置:对数据集进行分片, 不同进程读取不同子集。
d = tf.data.TFRecordDataset(input_file)
if is_training:
    if hvd is not None:
        d = d.shard(hvd.size(), hvd.rank())
    d = d.shuffle(buffer_size=100)
    d = d.repeat()

# 加载权重配置:只对第一个rank载入权重
if init_checkpoint and is_training and (hvd is None or hvd.rank()==0):
    for init_file in init_checkpoint.split(","):
        assignment_map, tmp_init_map = get_assignment_map_from_checkpoint(tvars, init_file, extra_load_var)
        tf.train.init_from_checkpoint(init_file, assignment_map)
        initialized_variable_names.update(tmp_init_map)

# 学习率调整:
if hvd:
    learning_rate = learning_rate * hvd.size()

# 分布式优化器配置:使用 ring-allreduce 平均梯度
if hvd is not None:
    # we enable compression only for fp16
    from horovod.tensorflow.compression import Compression
    if use_fp16:
        compression = Compression.fp16
    else:
        compression = Compression.none

    optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense=True,
                                         compression=compression)

# 配置每个进程模型迭代次数
if FLAGS.do_train:
    # train_examples = processor.get_train_examples(FLAGS.data_dir, FLAGS.img_dir)
    num_train_steps = int(
        train_num / FLAGS.train_batch_size * FLAGS.num_train_epochs)
        # len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    if hvd:
        num_train_steps = num_train_steps // hvd.size()

model_fn = model_fn_builder(
    bert_config=bert_config,
    num_labels=len(label_list),
    init_checkpoint=FLAGS.init_checkpoint,
    learning_rate=FLAGS.learning_rate,
    num_train_steps=num_train_steps,
    num_warmup_steps=num_warmup_steps)

# GPU config GPU配置:使用local rank分配当前机器上当前进程可视gpu
run_config = tf.ConfigProto()
# train_params.get('gpu_allow_growth', False)
run_config.gpu_options.allow_growth = True
run_config.allow_soft_placement = True

if hvd:
    run_config.gpu_options.visible_device_list = str(hvd.local_rank())

if FLAGS.use_xla:
    run_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

# checkpoint配置:只对第一个保存模型
save_checkpoints_steps = FLAGS.save_checkpoints_steps if hvd is None or hvd.rank() == 0 else None
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=FLAGS.output_dir,
    config=tf.estimator.RunConfig(
        save_checkpoints_steps=save_checkpoints_steps,
        save_checkpoints_secs=None,
        keep_checkpoint_every_n_hours=2,
        log_step_count_steps=400,
        session_config=run_config))


# 模型训练hook配置:将变量从第一个流程向其他流程传播,以实现一致性初始化。
if FLAGS.do_train and hvd is not None:
    training_hook = [hvd.BroadcastGlobalVariablesHook(0)]
else:
    training_hook = []
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps,
                hooks=training_hook)
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_16234613/article/details/96186398

智能推荐

java-Cloneable克隆对象内容_vv_wisher的博客-程序员秘密

有时候需要实体中的字段内容全部复制到一个新的实体中,BoardTest old= NewBoardTest();BoardTest new = old;但是当new = old 时,两个对象是同一地址,达不到复制的目的。可以通过克隆的方式,完成不同对象的内容复制。一个对象直接克隆为另一个对象时,会生成新的地址。1、实体实现Cloneablepublic cla...

尚学堂python开发工具_尚学堂百战程序员分享:Python的数据模型_weixin_39739395的博客-程序员秘密

接触 Python 有一段时间了,总结了很多关于Python的数据模型,但是到现在也没怎么用 Python 写过一些有用的东西。基础虽然还行,但更深入的就不怎么了解了。于是打算看一下《流畅的Python》。首先是数据模型,主要是 Python 的魔术方法(特殊方法),它们以双下划线开头和结束,能让我们自己写的类拥有类似Python内置对象那样的属性和方法。首先出场的是getitem和len。有了g...

springBoot学习笔记(完结)_springboot笔记_Java小墩墩的博客-程序员秘密

Spring是为了解决企业级应用开发的复杂性而创建的,简化开发。为了降低Java开发的复杂性,Spring采用了以下4种关键策略:1、基于POJO的轻量级和最小侵入性编程,所有东西都是bean;2、通过IOC,依赖注入(DI)和面向接口实现松耦合;3、基于切面(AOP)和惯例进行声明式编程;4、通过切面和模版减少样式代码,RedisTemplate,xxxTemplate;什么是SpringBoot呢,就是一个javaweb的开发框架,和SpringMVC类似,对比其他javaweb框架的好处,官方说是简化

Sql Server:多行合并成一行,并做分组统计_weixin_30648587的博客-程序员秘密

--创建test表,插入数据CREATETABLEtest(codevarchar(50),[values]varchar(10),[count]int)INSERTtestSELECT'001','aa',1UNIONALLSELECT'001','bb',2UNIONALLSELECT'002','aaa',4UNIONALLSELEC...

Linux的MySQL服务启动失败:Failed to start SYSV: MySQL databas...._linux failed to start sysv: mysql database server._ReflectMirroring的博客-程序员秘密

报错[[email protected] ~]# mysql -uroot -pEnter password: ERROR 2002 (HY000): Can't connect to local MySQL server through socket '/var/lib/mysql/mysql.sock' (111)[[email protected] ~]# service mysqld startStarting mysqld (via systemctl): Job for mysqld.servi

python3.9安装gdal库(便捷版)_python3.9.7安装gdal库_橙色的小太阳的博客-程序员秘密

GDAL(Geospatial Data Abstraction Library)是一个在X/MIT许可协议下的开源栅格空间数据转换库。它利用抽象数据模型来表达所支持的各种文件格式。它还有一系列命令行工具来进行数据转换和处理。在安装gdal库(pip install gdal)时碰到了一个很难缠的bug:error: Microsoft Visual C++ 14.0 is required. Get it with "Build Tools for Visual Studio": https://

随便推点

计算机电缆的铜丝和铜带的区别,请问铜带屏蔽和铜丝屏蔽那个好? - 无图版_正直博的博客-程序员秘密

ashui1986 --- 2010-02-26 11:15:481各位大侠请问铜带屏蔽和铜丝屏蔽那个好??要是铜带屏蔽效果好,那为什么标准规定是 铜丝屏蔽这样的屏蔽形式不是单独使用 而采用缠绕+反向扎带这样的屏蔽形式??望各位大侠能够给予详细的赐教 谢谢!!![yyyfff 在 2010-2-26 11:48:00 编辑过]yyyfff --- 2010-02-26 11:54:442按屏蔽...

Zenmap扫描udp端口太慢_udp端口扫描速度慢_Yangzf0628的博客-程序员秘密

Zenmap扫描udp端口太慢https://blog.csdn.net/weixin_39525300/article/details/119470968使用Zenmap扫描端口很慢,一秒才一个icmp包,受上面文章启发,看到服务端/proc/sys/net/ipv4目录下,存在icmp_ratelimit和icmp_ratemask,将两个都设置为不限制后,速度很快,十几二十秒即可完成扫描...

QT共享内存_yangluoning的博客-程序员秘密

使用创建 QSharedMemory 对象调用 create 成员函数分配共享内存,或者 attach 附加到已创建的共享内存使用内存 (注意lock、unlock)下面的例子很简单,不用多说。编译之后,运行3个实例。第一个创建共享内存,其他的读取共享内存:#include /QCoreApplication>#include /QSharedMemory>int m

前端网络基础 - axios源码分析_axiosinstance_伏城之外的博客-程序员秘密

前端网络基础 - axios使用_qfc_128220的博客-程序员秘密在上一节中,我们分析了axios的基本使用,其中有很多让人一时无法参悟透奥妙的设计。我们来逐一通过源码解析下。目录axios为什么既可以作为函数发送AJAX,也可以作为对象调用get,post等方法发送AJAX?Axios类axios默认的axios函数和axios.create新建的axios函数的差别在哪?axios拦截器是如何实现的axios执行流程axios取消请求简略版axios实

【安卓学习之第三方库】 消息推送之极光推送_笔夏的博客-程序员秘密

█ 【安卓学习之第三方库】 消息推送之极光推送█ 相关文章:-  ● 【安卓学习之第三方库】库的使用2-jar类库的使用(以dom4j为例)和升级(以极光推送为例)█ 读前说明:-  ● 本文通过学习别人写demo,学习一些课件,参考一些博客,’学习相关知识,如果涉及侵权请告知 ● 本文只简单罗列相关的代码实现过程 ● 涉及到的逻辑以及说明也只是简单介绍,主要当做笔记,了解过程而已█ 极光后台:-  ● 两年没有进极光后台,发现界面发生了很大变化,功能也有些变化,这边就整理下,

python爬虫学习之用Python抢火车票的简单小程序_Python新手学习之家的博客-程序员秘密

利用Python制作自动抢火车票小程序,过年再也不要担心没票了!前言每次过年很多人都会因为抢不到火车票而回不了家,所以小编利用Python写了一个自动抢火车票的工具,希望大家能抢到火车票,回家过个好年!话不多说,直接上代码:'''在学习过程中有什么不懂得可以加我的python学习交流扣扣qun,934109170群里有不错的学习视频教程、开发工具与电子书籍。与你分享p...

推荐文章

热门文章

相关标签