做住宿的网站扬州网络优化推广
加载模型的前几层拼接自己构建的层进行训练
注意,这里我们使用 nets.inception.inception_v3_base 来实现网络模型的部分恢复:inception_v3_base 可以通过 final_endpoint 参数指定网络的末尾层;然后在构造 tf.train.Saver 时传入需要恢复的变量列表(variables_to_restore),从而确定哪些权值需要从检查点恢复、哪些不需要恢复。
train.py
#-*-coding=utf-8-*-
from PIL import Image
import os
import os.path
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets
import inception_resnet_v2
import img_convert

# Height, width and channel count used for the shape of the input
# placeholder created in the __main__ block of this script.
height = 299
width = 299
channels = 3
# Module-level default; the __main__ block of this script reassigns it to 182.
num_classes = 1001


def convert(dir):
    """Load every file in *dir* as a 299x299 RGB image array.

    Args:
        dir: Path of a directory; every entry is opened with PIL, so the
            directory is expected to contain only image files.

    Returns:
        A list with one ``(299, 299, 3)`` array per directory entry, in the
        order returned by ``os.listdir``. The last axis is the R, G, B
        channel of each pixel.
    """
    arr_col = []
    for file_name in os.listdir(dir):
        file_path = os.path.join(dir, file_name)
        # The context manager releases the file handle as soon as the
        # resized RGB copy has been produced.
        with Image.open(file_path) as img:
            rgb_img = img.resize((299, 299)).convert("RGB")
        # Converting the RGB image directly keeps the channels interleaved
        # per pixel (height, width, channel). Concatenating the three
        # channel planes and reshaping them to (299, 299, 3) would instead
        # reinterpret planar data as interleaved and scramble the picture.
        arr_col.append(np.array(rgb_img))
    return arr_col
def convert_3_2_4_dims(arr_):
    """Stack a sequence of 3-D arrays into a single 4-D array.

    Args:
        arr_: Non-empty sequence of arrays. The first three dimensions of
            the first element determine the shape of every slot in the
            result.

    Returns:
        A float64 array of shape ``(len(arr_), d0, d1, d2)`` where
        ``(d0, d1, d2)`` is the shape of ``arr_[0]`` and ``ret[i]`` holds
        the values of ``arr_[i]``.

    Raises:
        IndexError: If *arr_* is empty.
    """
    first_shape = arr_[0].shape
    ret = np.zeros((len(arr_), first_shape[0], first_shape[1], first_shape[2]))
    for index, item in enumerate(arr_):
        ret[index, :, :, :] = item
    return ret
if __name__ == "__main__":
    # Fine-tunes a new fully connected classifier on top of the Inception-v3
    # layers up to Mixed_7c, restoring those layers from a checkpoint.
    o_dir = "E:/test"
    num_classes = 182
    batch_size = 3
    epoches = 2

    X = tf.placeholder(tf.float32, shape=[None, height, width, channels])
    y = tf.placeholder(tf.float32, shape=[None, num_classes])

    # Build the network only up to the requested final endpoint.
    with slim.arg_scope(nets.inception.inception_v3_arg_scope()):
        logits, end_points_ = nets.inception.inception_v3_base(
            X, final_endpoint='Mixed_7c')

    # Collect the variables BEFORE the new classifier variables are created,
    # so the saver below restores only the layers built above.
    variables_to_restore = slim.get_variables_to_restore()

    # Flatten the feature map so it can feed a fully connected layer.
    shape = logits.get_shape().as_list()
    dim = 1
    for d in shape[1:]:
        dim *= d
    fc_ = tf.reshape(logits, [-1, dim])

    # New classifier head; these variables are not in variables_to_restore
    # and keep their freshly initialised values.
    fc0_weights = tf.get_variable(
        name="fc0_weights",
        shape=(dim, num_classes),
        initializer=tf.contrib.layers.xavier_initializer())
    fc0_biases = tf.get_variable(
        name="fc0_biases",
        shape=(num_classes,),
        initializer=tf.contrib.layers.xavier_initializer())
    logits_ = tf.nn.bias_add(tf.matmul(fc_, fc0_weights), fc0_biases)
    predictions = tf.nn.softmax(logits_)

    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits_))
    train_step = tf.train.GradientDescentOptimizer(1e-6).minimize(cross_entropy)

    correct_pred = tf.equal(tf.argmax(y, 1), tf.argmax(predictions, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    with tf.Session() as sess:
        batches = img_convert.data_lrn(
            img_convert.load_data(o_dir, num_classes, batch_size))

        # Initialise everything first, then overwrite the restorable
        # variables with the checkpoint values.
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(variables_to_restore)
        saver.restore(sess, os.path.join("E:\\", "inception_v3.ckpt"))

        for epoch in range(epoches):
            for batch in batches:
                sess.run(train_step, feed_dict={X: batch[0], y: batch[1]})
            # Evaluate on the first batch; images and labels must come from
            # the SAME batch (pairing batch 0 images with batch 1 labels
            # mismatches them and fails when only one batch exists).
            # NOTE(review): per-epoch placement of this evaluation is
            # inferred from the flattened original - confirm.
            acc = sess.run(
                accuracy, feed_dict={X: batches[0][0], y: batches[0][1]})
            print(acc)
        print("Done")
img_convert.py
#coding=utf-8
from PIL import Image
import os
import os.path
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets
import inception_resnet_v2


def convert(dir):
    """Load every file in *dir* as a 299x299 RGB image array.

    Args:
        dir: Path of a directory; every entry is opened with PIL, so the
            directory is expected to contain only image files.

    Returns:
        A list with one ``(299, 299, 3)`` array per directory entry, in the
        order returned by ``os.listdir``. The last axis is the R, G, B
        channel of each pixel.
    """
    arr_col = []
    for file_name in os.listdir(dir):
        file_path = os.path.join(dir, file_name)
        # The context manager releases the file handle as soon as the
        # resized RGB copy has been produced.
        with Image.open(file_path) as img:
            rgb_img = img.resize((299, 299)).convert("RGB")
        # Converting the RGB image directly keeps the channels interleaved
        # per pixel (height, width, channel). Concatenating the three
        # channel planes and reshaping them to (299, 299, 3) would instead
        # reinterpret planar data as interleaved and scramble the picture.
        arr_col.append(np.array(rgb_img))
    return arr_col
def convert_3_2_4_dims(arr_):
    """Stack a sequence of 3-D arrays into a single 4-D array.

    Args:
        arr_: Non-empty sequence of arrays. The first three dimensions of
            the first element determine the shape of every slot in the
            result.

    Returns:
        A float64 array of shape ``(len(arr_), d0, d1, d2)`` where
        ``(d0, d1, d2)`` is the shape of ``arr_[0]`` and ``ret[i]`` holds
        the values of ``arr_[i]``.

    Raises:
        IndexError: If *arr_* is empty.
    """
    first_shape = arr_[0].shape
    ret = np.zeros((len(arr_), first_shape[0], first_shape[1], first_shape[2]))
    for index, item in enumerate(arr_):
        ret[index, :, :, :] = item
    return ret
def to_categorial(y, n_classes):
    """One-hot encode a sequence of integer class labels.

    Args:
        y: Sequence (list or 1-D array) of class indices, each in
            ``[0, n_classes)``; values are converted to integer indices.
        n_classes: Number of columns of the result.

    Returns:
        A float64 array of shape ``(len(y), n_classes)`` that is 1.0 at
        ``[i, y[i]]`` and 0.0 everywhere else.

    Raises:
        IndexError: If a label is outside the valid index range.
    """
    labels = np.asarray(y, dtype=np.intp)
    y_std = np.zeros([len(labels), n_classes])
    # One fancy-indexed assignment sets the hot entry of every row at once.
    y_std[np.arange(len(labels)), labels] = 1.0
    return y_std
def batch_list(x, y, batch_size):
    """Split parallel sequences *x* and *y* into consecutive batches.

    Args:
        x: Sliceable sequence of samples.
        y: Sliceable sequence of labels, aligned with *x*.
        batch_size: Maximum number of samples per batch; must be positive.

    Returns:
        A list of ``[x_slice, y_slice]`` lists. Every batch holds
        ``batch_size`` samples except possibly the last one, which holds the
        remainder. An empty *x* yields an empty list.

    Raises:
        ValueError: If *batch_size* is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    batches = []
    # Stepping through the start offsets covers the full batches and the
    # trailing remainder alike, including inputs shorter than one batch.
    for start in range(0, len(x), batch_size):
        stop = start + batch_size
        batches.append([x[start:stop], y[start:stop]])
    return batches
def load_data(dir, num_classes, batch_size):
    """Load the images in *dir* and pair them with random one-hot labels.

    Args:
        dir: Directory passed to ``convert`` to load the images.
        num_classes: Number of classes used for the random labels and the
            width of the one-hot label matrix.
        batch_size: Batch size passed to ``batch_list``.

    Returns:
        The list of ``[images, labels]`` batches produced by ``batch_list``,
        where images are float32 and labels are one-hot rows.
    """
    arr_col = convert_3_2_4_dims(convert(dir))
    arr_col = arr_col.astype(np.float32)
    # No real labels are specified here, so random labels are assigned.
    z = np.random.rand(arr_col.shape[0]) * num_classes
    labels = z.astype("int")
    batches = batch_list(arr_col, to_categorial(labels, num_classes), batch_size)
    return batches


def data_lrn(batches):
    """Divide the first element of every batch by 255, in place.

    Args:
        batches: List of mutable ``[images, labels]`` batches; ``images``
            must support in-place true division.

    Returns:
        The same *batches* list that was passed in.
    """
    for batch in batches:
        batch[0] /= 255
    return batches
更多:
Tensorflow读取并使用预训练模型:以inception_v3为例