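Starting on the 1st of this month, Shanghai began enforcing its "strictest ever" waste-sorting regulations: putting garbage in the wrong bin can bring a fine of up to 200 yuan. Another 46 cities across the country will follow into the new era of waste sorting. Jokes about people being driven mad by waste sorting keep appearing on social media.
From an artificial intelligence point of view, waste sorting is simply an application of image classification, one of the standard tasks in image processing. Among the entries in the ImageNet image classification challenge since 2012, SENet stands out with a top-5 test error rate of 2.25%, and can fairly be called the strongest image classifier so far.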
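I just visited the website of Momenta, the company that created SENet. Their latest work is already on 3D object recognition and calibration, with results like this:
[Figure: Momenta 3D object recognition and calibration demo]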
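So using their SENet to process images of garbage should pose no problem at all.
Introduction to SENet
SENet was proposed jointly by Momenta and the University of Oxford. It is a model built on squeeze and excitation operations: each block uses the "squeeze" operation to embed information from the global receptive field, and the "excitation" operation to selectively boost responses. Past ImageNet champions mostly raised accuracy by scaling up the model and the number of connections. Against this "brute force works wonders" trend, SENet is a breath of fresh air. The paper is available at: http://openaccess.thecvf.com/content_cvpr_2018/papers/Hu_Squeeze-and-Excitation_Networks_CVPR_2018_paper.pdf
The principle works as follows: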
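Squeeze: apply global average pooling to the C×H×W feature map to obtain a 1×1×C feature map, which can be regarded as having a global receptive field. In the words of the paper: each two-dimensional feature channel is turned into a single real number that to some extent has a global receptive field, and the output dimension matches the number of input feature channels. It characterizes the global distribution of responses over the feature channels, and lets layers close to the input obtain a global receptive field as well.
Excitation: use a fully connected network to apply a non-linear transformation to the squeeze output. The mechanism resembles the gates in a recurrent neural network. A parameter w generates a weight for each feature channel, and w is learned to model the correlations between feature channels explicitly.
Feature recalibration: use the excitation output as weights and multiply them onto the input features. These weights can be read as the importance of each feature channel; applying them channel by channel to the earlier features completes the recalibration of the original features.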
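The three steps above can be sketched in a few lines of NumPy. This is only an illustration, not the paper's or the repository's code: it assumes an H×W×C layout (the layout the TensorFlow listing later in this article uses), random matrices `w1` and `w2` standing in for learned weights, and a reduction ratio of 4. The function name `se_block` is hypothetical.

```python
import numpy as np


def se_block(feature_map, w1, b1, w2, b2):
    """Apply squeeze, excitation and recalibration to one H x W x C feature map."""
    # Squeeze: global average pooling reduces each channel to a single number.
    squeezed = feature_map.mean(axis=(0, 1))            # shape (C,)
    # Excitation: bottleneck FC -> ReLU -> FC -> sigmoid gives one gate per channel.
    hidden = np.maximum(0.0, squeezed @ w1 + b1)        # shape (C // ratio,)
    gates = 1.0 / (1.0 + np.exp(-(hidden @ w2 + b2)))   # shape (C,), values in (0, 1)
    # Recalibration: scale every channel by its gate (broadcast over H and W).
    return feature_map * gates, gates


# Random stand-ins for a feature map and for the learned weights (assumptions).
rng = np.random.default_rng(0)
height, width, channels, ratio = 8, 8, 16, 4
features = rng.standard_normal((height, width, channels))
w1 = rng.standard_normal((channels, channels // ratio)) * 0.1
b1 = np.zeros(channels // ratio)
w2 = rng.standard_normal((channels // ratio, channels)) * 0.1
b2 = np.zeros(channels)

recalibrated, gates = se_block(features, w1, b1, w2, b2)
print(recalibrated.shape, gates.shape)  # (8, 8, 16) (16,)
```

The output keeps the shape of the input, which is why the block can be dropped between existing layers without changing the rest of the network.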
The model architecture is as follows:
[Figure: SENet block architecture]
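SENet is very simple to build and easy to deploy, with no need to introduce new functions or layers. Its Caffe model can be downloaded from Baidu (https://pan.baidu.com/s/1o7HdfAE?errno=0&errmsg=Auth%20Login%20Sucess&&bduss=&ssnerror=0&traceid=)
Using SENet
If you have Caffe deployed, just download the model above and load it directly. If you have TensorFlow installed rather than Caffe, that is fine too: as noted above, SENet introduces no new functions or layers, so it is easy to implement in TensorFlow.
Download the image set: after searching around I found this dataset. It is small and does not show SENet at its best, but it is convenient to use: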
https://raw.githubusercontent.com/garythung/trashnet/master/data/dataset-resized.zip
Build the SENet model: an open-source TensorFlow implementation is already available on GitHub at https://github.com/taki0112/SENet-Tensorflow. It uses the CIFAR-10 dataset, but that does not matter: after running git clone, just modify the prepare_data function in its cifar10.py as follows.
def prepare_data():
    print("======Loading data======")
    download_data()
    data_dir = 'e:/test/'
    # data_dir = './cifar-10-batches-py'  # change this to your own folder
    image_dim = image_size * image_size * img_channels
    # meta = unpickle(data_dir + '/batches.meta')  # this dataset has no meta file for labels, so this line must change
    label_names = ['cardboard', 'glass', 'metal', 'trash', 'paper', 'plastic']
    label_count = len(label_names)
    # train_files = ['data_batch_%d' % d for d in range(1, 6)]
    train_files = [data_dir + s for s in label_names]  # changed to this
    train_data, train_labels = load_data(train_files, data_dir, label_count)
    test_data, test_labels = load_data(['test_batch'], data_dir, label_count)

    print("Train data:", np.shape(train_data), np.shape(train_labels))
    print("Test data:", np.shape(test_data), np.shape(test_labels))
    print("======Load finished======")

    print("======Shuffling data======")
    indices = np.random.permutation(len(train_data))
    train_data = train_data[indices]
    train_labels = train_labels[indices]
    print("======Prepare Finished======")

    return train_data, train_labels, test_data, test_labels
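The modified prepare_data passes class folder names to load_data, but the article does not show how load_data reads them. As a rough sketch of that step, assuming the archive is unzipped into one sub-folder per class, the file paths and labels could be collected as below. The helper names `list_labeled_files`, `one_hot` and `split_train_test`, and the 20% test fraction, are hypothetical and not part of the repository.

```python
import os
import random

LABEL_NAMES = ['cardboard', 'glass', 'metal', 'trash', 'paper', 'plastic']


def list_labeled_files(data_dir, label_names=LABEL_NAMES):
    """Return (path, label_index) pairs, assuming one sub-folder per class."""
    samples = []
    for index, name in enumerate(label_names):
        class_dir = os.path.join(data_dir, name)
        if not os.path.isdir(class_dir):
            continue  # tolerate a missing class folder
        for filename in sorted(os.listdir(class_dir)):
            if filename.lower().endswith(('.jpg', '.jpeg', '.png')):
                samples.append((os.path.join(class_dir, filename), index))
    return samples


def one_hot(index, count):
    """Encode a class index as a one-hot list of length count."""
    return [1.0 if i == index else 0.0 for i in range(count)]


def split_train_test(samples, test_fraction=0.2, seed=0):
    """Shuffle the samples reproducibly and split off a test set."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)
    test_size = int(len(shuffled) * test_fraction)
    return shuffled[test_size:], shuffled[:test_size]
```

A call such as `train, test = split_train_test(list_labeled_files('e:/test/'))` would then give two lists of labeled paths. Decoding each image and resizing it to image_size × image_size is not shown here.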
The main modeling code follows. Its main job is simply to implement the SENet model structure:
# Written for TensorFlow 1.x (tf.contrib was removed in TensorFlow 2.x); also needs tflearn.
import tensorflow as tf
from tflearn.layers.conv import global_avg_pool
from tensorflow.contrib.layers import batch_norm, flatten
from tensorflow.contrib.framework import arg_scope
from cifar10 import *
import numpy as np

weight_decay = 0.0005
momentum = 0.9

init_learning_rate = 0.1

reduction_ratio = 4

batch_size = 128
iteration = 391
# 128 * 391 ~ 50,000

test_iteration = 10

total_epochs = 100

def conv_layer(input, filter, kernel, stride=1, padding='SAME', layer_name="conv", activation=True):
    with tf.name_scope(layer_name):
        network = tf.layers.conv2d(inputs=input, use_bias=True, filters=filter, kernel_size=kernel, strides=stride, padding=padding)
        if activation:
            network = Relu(network)
        return network

def Fully_connected(x, units=class_num, layer_name='fully_connected'):
    with tf.name_scope(layer_name):
        return tf.layers.dense(inputs=x, use_bias=True, units=units)

def Relu(x):
    return tf.nn.relu(x)

def Sigmoid(x):
    return tf.nn.sigmoid(x)

def Global_Average_Pooling(x):
    return global_avg_pool(x, name='Global_avg_pooling')

def Max_pooling(x, pool_size=[3, 3], stride=2, padding='VALID'):
    return tf.layers.max_pooling2d(inputs=x, pool_size=pool_size, strides=stride, padding=padding)

def Batch_Normalization(x, training, scope):
    with arg_scope([batch_norm],
                   scope=scope,
                   updates_collections=None,
                   decay=0.9,
                   center=True,
                   scale=True,
                   zero_debias_moving_mean=True):
        return tf.cond(training,
                       lambda: batch_norm(inputs=x, is_training=training, reuse=None),
                       lambda: batch_norm(inputs=x, is_training=training, reuse=True))

def Concatenation(layers):
    return tf.concat(layers, axis=3)

def Dropout(x, rate, training):
    return tf.layers.dropout(inputs=x, rate=rate, training=training)

def Evaluate(sess):
    test_acc = 0.0
    test_loss = 0.0
    test_pre_index = 0
    add = 1000

    for it in range(test_iteration):
        test_batch_x = test_x[test_pre_index: test_pre_index + add]
        test_batch_y = test_y[test_pre_index: test_pre_index + add]
        test_pre_index = test_pre_index + add

        test_feed_dict = {
            x: test_batch_x,
            label: test_batch_y,
            learning_rate: epoch_learning_rate,
            training_flag: False
        }

        loss_, acc_ = sess.run([cost, accuracy], feed_dict=test_feed_dict)

        test_loss += loss_
        test_acc += acc_

    test_loss /= test_iteration  # average loss
    test_acc /= test_iteration  # average accuracy

    summary = tf.Summary(value=[tf.Summary.Value(tag='test_loss', simple_value=test_loss),
                                tf.Summary.Value(tag='test_accuracy', simple_value=test_acc)])

    return test_acc, test_loss, summary

class SE_Inception_resnet_v2():
    def __init__(self, x, training):
        self.training = training
        self.model = self.Build_SEnet(x)

    def Stem(self, x, scope):
        with tf.name_scope(scope):
            x = conv_layer(x, filter=32, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_conv1')
            x = conv_layer(x, filter=32, kernel=[3, 3], padding='VALID', layer_name=scope+'_conv2')
            block_1 = conv_layer(x, filter=64, kernel=[3, 3], layer_name=scope+'_conv3')

            split_max_x = Max_pooling(block_1)
            split_conv_x = conv_layer(block_1, filter=96, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_split_conv1')
            x = Concatenation([split_max_x, split_conv_x])

            split_conv_x1 = conv_layer(x, filter=64, kernel=[1, 1], layer_name=scope+'_split_conv2')
            split_conv_x1 = conv_layer(split_conv_x1, filter=96, kernel=[3, 3], padding='VALID', layer_name=scope+'_split_conv3')

            split_conv_x2 = conv_layer(x, filter=64, kernel=[1, 1], layer_name=scope+'_split_conv4')
            split_conv_x2 = conv_layer(split_conv_x2, filter=64, kernel=[7, 1], layer_name=scope+'_split_conv5')
            split_conv_x2 = conv_layer(split_conv_x2, filter=64, kernel=[1, 7], layer_name=scope+'_split_conv6')
            split_conv_x2 = conv_layer(split_conv_x2, filter=96, kernel=[3, 3], padding='VALID', layer_name=scope+'_split_conv7')

            x = Concatenation([split_conv_x1, split_conv_x2])

            split_conv_x = conv_layer(x, filter=192, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_split_conv8')
            split_max_x = Max_pooling(x)

            x = Concatenation([split_conv_x, split_max_x])

            x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
            x = Relu(x)

            return x

    def Inception_resnet_A(self, x, scope):
        with tf.name_scope(scope):
            init = x

            split_conv_x1 = conv_layer(x, filter=32, kernel=[1, 1], layer_name=scope+'_split_conv1')

            split_conv_x2 = conv_layer(x, filter=32, kernel=[1, 1], layer_name=scope+'_split_conv2')
            split_conv_x2 = conv_layer(split_conv_x2, filter=32, kernel=[3, 3], layer_name=scope+'_split_conv3')

            split_conv_x3 = conv_layer(x, filter=32, kernel=[1, 1], layer_name=scope+'_split_conv4')
            split_conv_x3 = conv_layer(split_conv_x3, filter=48, kernel=[3, 3], layer_name=scope+'_split_conv5')
            split_conv_x3 = conv_layer(split_conv_x3, filter=64, kernel=[3, 3], layer_name=scope+'_split_conv6')

            x = Concatenation([split_conv_x1, split_conv_x2, split_conv_x3])
            x = conv_layer(x, filter=384, kernel=[1, 1], layer_name=scope+'_final_conv1', activation=False)

            x = x * 0.1
            x = init + x

            x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
            x = Relu(x)

            return x

    def Inception_resnet_B(self, x, scope):
        with tf.name_scope(scope):
            init = x

            split_conv_x1 = conv_layer(x, filter=192, kernel=[1, 1], layer_name=scope+'_split_conv1')

            split_conv_x2 = conv_layer(x, filter=128, kernel=[1, 1], layer_name=scope+'_split_conv2')
            split_conv_x2 = conv_layer(split_conv_x2, filter=160, kernel=[1, 7], layer_name=scope+'_split_conv3')
            split_conv_x2 = conv_layer(split_conv_x2, filter=192, kernel=[7, 1], layer_name=scope+'_split_conv4')

            x = Concatenation([split_conv_x1, split_conv_x2])
            x = conv_layer(x, filter=1152, kernel=[1, 1], layer_name=scope+'_final_conv1', activation=False)
            # 1154
            x = x * 0.1
            x = init + x

            x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
            x = Relu(x)

            return x

    def Inception_resnet_C(self, x, scope):
        with tf.name_scope(scope):
            init = x

            split_conv_x1 = conv_layer(x, filter=192, kernel=[1, 1], layer_name=scope+'_split_conv1')

            split_conv_x2 = conv_layer(x, filter=192, kernel=[1, 1], layer_name=scope+'_split_conv2')
            split_conv_x2 = conv_layer(split_conv_x2, filter=224, kernel=[1, 3], layer_name=scope+'_split_conv3')
            split_conv_x2 = conv_layer(split_conv_x2, filter=256, kernel=[3, 1], layer_name=scope+'_split_conv4')

            x = Concatenation([split_conv_x1, split_conv_x2])
            x = conv_layer(x, filter=2144, kernel=[1, 1], layer_name=scope+'_final_conv2', activation=False)
            # 2048
            x = x * 0.1
            x = init + x

            x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
            x = Relu(x)

            return x

    def Reduction_A(self, x, scope):
        with tf.name_scope(scope):
            k = 256
            l = 256
            m = 384
            n = 384

            split_max_x = Max_pooling(x)

            split_conv_x1 = conv_layer(x, filter=n, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_split_conv1')

            split_conv_x2 = conv_layer(x, filter=k, kernel=[1, 1], layer_name=scope+'_split_conv2')
            split_conv_x2 = conv_layer(split_conv_x2, filter=l, kernel=[3, 3], layer_name=scope+'_split_conv3')
            split_conv_x2 = conv_layer(split_conv_x2, filter=m, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_split_conv4')

            x = Concatenation([split_max_x, split_conv_x1, split_conv_x2])

            x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
            x = Relu(x)

            return x

    def Reduction_B(self, x, scope):
        with tf.name_scope(scope):
            split_max_x = Max_pooling(x)

            split_conv_x1 = conv_layer(x, filter=256, kernel=[1, 1], layer_name=scope+'_split_conv1')
            split_conv_x1 = conv_layer(split_conv_x1, filter=384, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_split_conv2')

            split_conv_x2 = conv_layer(x, filter=256, kernel=[1, 1], layer_name=scope+'_split_conv3')
            split_conv_x2 = conv_layer(split_conv_x2, filter=288, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_split_conv4')

            split_conv_x3 = conv_layer(x, filter=256, kernel=[1, 1], layer_name=scope+'_split_conv5')
            split_conv_x3 = conv_layer(split_conv_x3, filter=288, kernel=[3, 3], layer_name=scope+'_split_conv6')
            split_conv_x3 = conv_layer(split_conv_x3, filter=320, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_split_conv7')

            x = Concatenation([split_max_x, split_conv_x1, split_conv_x2, split_conv_x3])

            x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
            x = Relu(x)

            return x

    def Squeeze_excitation_layer(self, input_x, out_dim, ratio, layer_name):
        with tf.name_scope(layer_name):
            # Squeeze: one value per channel.
            squeeze = Global_Average_Pooling(input_x)

            # Excitation: bottleneck FC -> ReLU -> FC -> sigmoid.
            # Integer division keeps units an int under Python 3.
            excitation = Fully_connected(squeeze, units=out_dim // ratio, layer_name=layer_name+'_fully_connected1')
            excitation = Relu(excitation)
            excitation = Fully_connected(excitation, units=out_dim, layer_name=layer_name+'_fully_connected2')
            excitation = Sigmoid(excitation)

            # Recalibration: scale each channel of the input by its weight.
            excitation = tf.reshape(excitation, [-1, 1, 1, out_dim])
            scale = input_x * excitation

            return scale

    def Build_SEnet(self, input_x):
        input_x = tf.pad(input_x, [[0, 0], [32, 32], [32, 32], [0, 0]])
        # size 32 -> 96
        print(np.shape(input_x))
        # only cifar10 architecture

        x = self.Stem(input_x, scope='stem')

        for i in range(5):
            x = self.Inception_resnet_A(x, scope='Inception_A'+str(i))
            channel = int(np.shape(x)[-1])
            x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_A'+str(i))

        x = self.Reduction_A(x, scope='Reduction_A')

        channel = int(np.shape(x)[-1])
        x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_A')

        for i in range(10):
            x = self.Inception_resnet_B(x, scope='Inception_B'+str(i))
            channel = int(np.shape(x)[-1])
            x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_B'+str(i))

        x = self.Reduction_B(x, scope='Reduction_B')

        channel = int(np.shape(x)[-1])
        x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_B')

        for i in range(5):
            x = self.Inception_resnet_C(x, scope='Inception_C'+str(i))
            channel = int(np.shape(x)[-1])
            x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_C'+str(i))

        # channel = int(np.shape(x)[-1])
        # x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_C')

        x = Global_Average_Pooling(x)
        x = Dropout(x, rate=0.2, training=self.training)
        x = flatten(x)

        x = Fully_connected(x, layer_name='final_fully_connected')
        return x


train_x, train_y, test_x, test_y = prepare_data()
train_x, test_x = color_preprocessing(train_x, test_x)

# image_size = 32, img_channels = 3, class_num = 10 in cifar10
x = tf.placeholder(tf.float32, shape=[None, image_size, image_size, img_channels])
label = tf.placeholder(tf.float32, shape=[None, class_num])

training_flag = tf.placeholder(tf.bool)

learning_rate = tf.placeholder(tf.float32, name='learning_rate')

logits = SE_Inception_resnet_v2(x, training=training_flag).model
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=label, logits=logits))

l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=momentum, use_nesterov=True)
train = optimizer.minimize(cost + l2_loss * weight_decay)

correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

saver = tf.train.Saver(tf.global_variables())

with tf.Session() as sess:
    # Resume from a checkpoint if one exists, otherwise start from scratch.
    ckpt = tf.train.get_checkpoint_state('./model')
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())

    summary_writer = tf.summary.FileWriter('./logs', sess.graph)

    epoch_learning_rate = init_learning_rate
    for epoch in range(1, total_epochs + 1):
        # Divide the learning rate by 10 every 30 epochs.
        if epoch % 30 == 0:
            epoch_learning_rate = epoch_learning_rate / 10

        pre_index = 0
        train_acc = 0.0
        train_loss = 0.0

        for step in range(1, iteration + 1):
            # 50000 is the CIFAR-10 training-set size.
            if pre_index + batch_size < 50000:
                batch_x = train_x[pre_index: pre_index + batch_size]
                batch_y = train_y[pre_index: pre_index + batch_size]
            else:
                batch_x = train_x[pre_index:]
                batch_y = train_y[pre_index:]

            batch_x = data_augmentation(batch_x)

            train_feed_dict = {
                x: batch_x,
                label: batch_y,
                learning_rate: epoch_learning_rate,
                training_flag: True
            }

            _, batch_loss = sess.run([train, cost], feed_dict=train_feed_dict)
            batch_acc = accuracy.eval(feed_dict=train_feed_dict)

            train_loss += batch_loss
            train_acc += batch_acc
            pre_index += batch_size

        train_loss /= iteration  # average loss
        train_acc /= iteration  # average accuracy

        train_summary = tf.Summary(value=[tf.Summary.Value(tag='train_loss', simple_value=train_loss),
                                          tf.Summary.Value(tag='train_accuracy', simple_value=train_acc)])

        test_acc, test_loss, test_summary = Evaluate(sess)

        summary_writer.add_summary(summary=train_summary, global_step=epoch)
        summary_writer.add_summary(summary=test_summary, global_step=epoch)
        summary_writer.flush()

        line = "epoch: %d/%d, train_loss: %.4f, train_acc: %.4f, test_loss: %.4f, test_acc: %.4f \n" % (
            epoch, total_epochs, train_loss, train_acc, test_loss, test_acc)
        print(line)

        with open('logs.txt', 'a') as f:
            f.write(line)

        saver.save(sess=sess, save_path='./model/Inception_resnet_v2.ckpt')
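The training loop hardcodes two CIFAR-10 values: 391 iterations per epoch and a training-set size of 50,000. For a dataset of a different size, both would presumably need to be derived from the data. The sketch below shows one way to do that, together with the step decay the loop applies (learning rate divided by 10 every 30 epochs). The helper names `iterations_per_epoch`, `batch_bounds` and `step_decay` are hypothetical and not part of the repository.

```python
import math


def iterations_per_epoch(num_samples, batch_size):
    """Number of batches needed to cover every sample once."""
    return math.ceil(num_samples / batch_size)


def batch_bounds(num_samples, batch_size):
    """Yield (start, end) index pairs; the last batch may be smaller."""
    for start in range(0, num_samples, batch_size):
        yield start, min(start + batch_size, num_samples)


def step_decay(init_learning_rate, epoch, drop_every=30, factor=10.0):
    """Learning rate at a 1-based epoch, divided by factor every drop_every epochs."""
    return init_learning_rate / (factor ** (epoch // drop_every))


print(iterations_per_epoch(50000, 128))  # 391, the CIFAR-10 value used in the listing
print(list(batch_bounds(300, 128)))      # [(0, 128), (128, 256), (256, 300)]
```

With `len(train_x)` in place of 50000, the same slicing would work unchanged for the smaller garbage dataset.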
Using SENet for waste sorting is really overkill, but it does let you get a feel for how powerful it is.
Original title: Still Struggling with Waste Sorting? Use Python to Experience the Power of SENet, the ImageNet Champion Model
Source: WeChat ID rgznai100, WeChat official account rgznai100. Please credit the source when reposting.