12月29, 2019

tensorflow fine_tune技巧

由于很多时候我们在一个新的网络中只会用到一个已训练模型的部分参数,即迁移学习。 那么,如何加载已训练模型的部分参数到当前网络。

一、 前网络加载已训练模型相同name scope的变量

方法1. 手动构建与预训练一样的部分图 将需要fine tune的变量的name scope命名为与模型中的name scope相同,然后使用如下代码将模型参数加载到当前网络。

tf.train.Saver([var for var in tf.global_variables() if var.name.startswith('train')]) \
            .restore(sess,' D:\\tuxiang\\hhh\\my_test_model-1')

二、当前网络与已训练模型的name scope不一致的情况

例如:将训练模型name scope ‘a’的值赋给当前网络的name scope ‘m/b/t’。

方法1:改写已保存模型的name scope,使其与目标变量的name scope一致;或者将需要加载到当前网络的参数选取出来写入一个新的模型,然后直接加载就ok. name scope改写代码,也可进行name scope删选,剔除不需要加载的变量。

import os
import tensorflow as tf
import numpy as np

# 新模型的存储地址
new_checkpoint_path = 'D:\\tuxiang\\hhh\\h\\'
# 旧模型的存储地址
checkpoint_path = 'D:\\tuxiang\\hhh\\my_test_model-1'
# 添加的name scope
add_prefix  = 'main/'
if not os.path.exists(new_checkpoint_path):
    os.makedirs(new_checkpoint_path)
with tf.Session() as sess:
    new_var_list = []  # 新建一个空列表存储更新后的Variable变量
    for var_name, _ in tf.contrib.framework.list_variables(checkpoint_path):  # 得到checkpoint文件中所有的参数(名字,形状)元组
        var = tf.contrib.framework.load_variable(checkpoint_path, var_name)  # 得到上述参数的值
        # var_name为变量的name scope,是一个字符串,可以进行改写
        # var 是该name scope对应的值
        print(var_name,var)
        new_name = var_name
        new_name = add_prefix + new_name  # 在这里加入了名称前缀,大家可以自由地作修改
        # 除了修改参数名称,还可以修改参数值(var)
        print('Renaming %s to %s.' % (var_name, new_name))
        renamed_var = tf.Variable(var, name=new_name)  # 使用加入前缀的新名称重新构造了参数
        new_var_list.append(renamed_var)  # 把赋予新名称的参数加入空列表
    print('starting to write new checkpoint !')
    saver = tf.train.Saver(var_list=new_var_list)  # 构造一个保存器
    sess.run(tf.global_variables_initializer())  # 初始化一下参数(这一步必做)
    model_name = 'deeplab_resnet_altered'  # 构造一个保存的模型名称
    checkpoint_path = os.path.join(new_checkpoint_path, model_name)  # 构造一下保存路径
    saver.save(sess, checkpoint_path)  # 直接进行保存
    print("done !")

方法2:直接通过tf.train.import_meta_graph()和saver.restore()将模型的所有参数加载到当前图中,然后再使用 sess.run(name scope)和sess.run(tf.assign())取出模型所在的name scope的数值赋给网络中需要fine tune的变量.

使用此方法会将预训练模型的所有参数和图加载进来并在保存的时候与当前网络一起保存,使参数更加庞大,必须在训练结束后定义需要保存的变量,避免保存所有参数。

with tf.variable_scope('train/x'):
    w1 = tf.get_variable('w1', shape = [2])  
    w2 = tf.get_variable( name='w2',shape=[2])  
    w3 = tf.get_variable( name='w3',shape=[2])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) 
    print(sess.run('train/x/w3:0'))
    saver =tf.train.import_meta_graph('D:\\tuxiang\\hhh\\my_test_model-1.meta')
    saver.restore(sess,'D:\\tuxiang\\hhh\\my_test_model-1')
    data = sess.run('train/w1:0')

    print(data)  
    sess.run(tf.assign(w1,data))
    print(sess.run('train/x/w1:0'))

使用TensorFlow-Slim获取局部参数

exclude = ['weight1','weight2']
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
init_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)

init_fn(sess)

三、固定加载进来的参数

对于某些fine tune的参数,如果希望将这些参数固定住,即不训练,可以通过在定义变量的时候设置trainable=False;也可以在训练过程中的minimize()中添加需要训练的参数,则未添加的参数会固定住。

用var_list = tf.contrib.framework.get_variables(scope_name)获取指定scope_name下的变量, 然后optimizer.minimize()时传入指定var_list即可。

train_op = tf.train.GradientDescentOptimizer.minimize(loss,var_list=var_list)

参考文献: tensorflow fine_tune已训练模型的部分参数

本文链接:http://57km.cc/post/tensorflow fine_tune.html

-- EOF --

Comments