作者:大森林| 來源:3DCV
1. NeRF定義
神經輻射場(NeRF)是一種利用神經網絡來表示和渲染復雜的三維場景的方法。它可以從一組二維圖片中學習出一個連續的三維函數,這個函數可以給出空間中任意位置和方向上的顏色和密度。通過體積渲染的技術,NeRF可以從任意視角合成出逼真的圖像,包括透明和半透明物體,以及復雜的光線傳播效果。
2. NeRF優勢
NeRF模型相比于其他新的視圖合成和場景表示方法有以下幾個優勢:
1)NeRF不需要離散化的三維表示,如網格或體素,因此可以避免模型精度和細節程度受到限制。NeRF也可以自適應地處理不同形狀和大小的場景,而不需要人工調整參數。
2)NeRF使用位置編碼的方式將位置和角度信息映射到高頻域,使得網絡能夠更好地捕捉場景的細微結構和變化。NeRF還使用視角相關的顏色預測,能夠生成不同視角下不同的光照效果。
3)NeRF使用分段隨機采樣的方式來近似體積渲染的積分,這樣可以保證采樣位置的連續性,同時避免網絡過擬合于離散點的信息。NeRF還使用多層級體素采樣的技巧,以提高渲染效率和質量。
3. NeRF實現步驟
1)定義一個全連接的神經網絡,它的輸入是空間位置和視角方向,輸出是顏色和密度。
2)使用位置編碼的方式將輸入映射到高頻域,以便網絡能夠捕捉細微的結構和變化。
3)使用分段隨機采樣的方式從每條光線上采樣一些點,然后用神經網絡預測這些點的顏色和密度。
4)使用體積渲染的公式計算每條光線上的顏色和透明度,作為最終的圖像輸出。
5)使用渲染損失函數來優化神經網絡的參數,使得渲染的圖像與輸入的圖像盡可能接近。
importtorch importtorch.nnasnn importtorch.nn.functionalasF #定義一個全連接的神經網絡,它的輸入是空間位置和視角方向,輸出是顏色和密度。 classNeRF(nn.Module): def__init__(self,D=8,W=256,input_ch=3,input_ch_views=3,output_ch=4,skips=[4]): super().__init__() #定義位置編碼后的位置信息的線性層,如果層數在skips列表中,則將原始位置信息與隱藏層拼接 self.pts_linears=nn.ModuleList( [nn.Linear(input_ch,W)]+[nn.Linear(W,W)ifinotinskipselsenn.Linear(W+input_ch,W)foriinrange(D-1)]) #定義位置編碼后的視角方向信息的線性層 self.views_linears=nn.ModuleList([nn.Linear(W+input_ch_views,W//2)]+[nn.Linear(W//2,W//2)foriinrange(1)]) #定義特征向量的線性層 self.feature_linear=nn.Linear(W//2,W) #定義透明度(alpha)值的線性層 self.alpha_linear=nn.Linear(W,1) #定義RGB顏色的線性層 self.rgb_linear=nn.Linear(W+input_ch_views,3) defforward(self,x): #x:(B,input_ch+input_ch_views) #提取位置和視角方向信息 p=x[:,:3]#(B,3) d=x[:,3:]#(B,3) #對輸入進行位置編碼,將低頻信號映射到高頻域 p=positional_encoding(p)#(B,input_ch) d=positional_encoding(d)#(B,input_ch_views) #將位置信息輸入網絡 h=p fori,linenumerate(self.pts_linears): h=l(h) h=F.relu(h) ifiinskips: h=torch.cat([h,p],-1)#如果層數在skips列表中,則將原始位置信息與隱藏層拼接 #將視角方向信息與隱藏層拼接,并輸入網絡 h=torch.cat([h,d],-1) fori,linenumerate(self.views_linears): h=l(h) h=F.relu(h) #預測特征向量和透明度(alpha)值 feature=self.feature_linear(h)#(B,W) alpha=self.alpha_linear(feature)#(B,1) #使用特征向量和視角方向信息預測RGB顏色 rgb=torch.cat([feature,d],-1) rgb=self.rgb_linear(rgb)#(B,3) returntorch.cat([rgb,alpha],-1)#(B,4) #定義位置編碼函數 defpositional_encoding(x): #x:(B,C) B,C=x.shape L=int(C//2)#計算位置編碼的長度 freqs=torch.logspace(0.,L-1,steps=L).to(x.device)*math.pi#計算頻率系數,呈指數增長 freqs=freqs[None].repeat(B,1)#(B,L) x_pos_enc_low=torch.sin(x[:,:L]*freqs)#對前一半的輸入進行正弦變換,得到低頻部分(B,L) x_pos_enc_high=torch.cos(x[:,:L]*freqs)#對前一半的輸入進行余弦變換,得到高頻部分(B,L) x_pos_enc=torch.cat([x_pos_enc_low,x_pos_enc_high],dim=-1)#將低頻和高頻部分拼接,得到位置編碼后的輸入(B,C) returnx_pos_enc #定義體積渲染函數 defvolume_rendering(rays_o,rays_d,model): #rays_o:(B,3),每條光線的起點 #rays_d:(B,3),每條光線的方向 B=rays_o.shape[0] #在每條光線上采樣一些點 near,far=0.,1.#近平面和遠平面 N_samples=64#每條光線的采樣數 t_vals=torch.linspace(near,far,N_samples).to(rays_o.device)#(N_samples,) t_vals=t_vals.expand(B,N_samples)#(B,N_samples) z_vals=near*(1.-t_vals)+far*t_vals#計算每個采樣點的深度值(B,N_samples) z_vals=z_vals.unsqueeze(-1)#(B,N_samples,1) pts=rays_o.unsqueeze(1)+rays_d.unsqueeze(1)*z_vals#計算每個采樣點的空間位置(B,N_samples,3) #將采樣點和視角方向輸入網絡 pts_flat=pts.reshape(-1,3)#(B*N_samples,3) rays_d_flat=rays_d.unsqueeze(1).expand(-1,N_samples,-1).reshape(-1,3)#(B*N_samples,3) x_flat=torch.cat([pts_flat,rays_d_flat],-1)#(B*N_samples,6) y_flat=model(x_flat)#(B*N_samples,4) y=y_flat.reshape(B,N_samples,4)#(B,N_samples,4) #提取RGB顏色和透明度(alpha)值 rgb=y[...,:3]#(B,N_samples,3) alpha=y[...,3]#(B,N_samples) #計算每個采樣點的權重 dists=torch.cat([z_vals[...,1:]-z_vals[...,:-1],torch.tensor([1e10]).to(z_vals.device).expand(B,1)],-1)#計算相鄰采樣點之間的距離,最后一個距離設為很大的值(B,N_samples) alpha=1.-torch.exp(-alpha*dists)#計算每個采樣點的不透明度,即1減去透明度的指數衰減(B,N_samples) weights=alpha*torch.cumprod(torch.cat([torch.ones((B,1)).to(alpha.device),1.-alpha+1e-10],-1),-1)[:,:-1]#計算每個采樣點的權重,即不透明度乘以之前所有采樣點的透明度累積積,最后一個權重設為0(B,N_samples) #計算每條光線的最終顏色和透明度 rgb_map=torch.sum(weights.unsqueeze(-1)*rgb,-2)#加權平均每個采樣點的RGB顏色,得到每條光線的顏色(B,3) depth_map=torch.sum(weights*z_vals.squeeze(-1),-1)#加權平均每個采樣點的深度值,得到每條光線的深度(B,) acc_map=torch.sum(weights,-1)#累加每個采樣點的權重,得到每條光線的不透明度(B,) returnrgb_map,depth_map,acc_map #定義渲染損失函數 defrendering_loss(rgb_map_pred,rgb_map_gt): return((rgb_map_pred-rgb_map_gt)**2).mean()#計算預測的顏色與真實顏色之間的均方誤差
綜上所述,本代碼實現了NeRF的核心結構,具體實現內容包括以下四個部分。
1)定義了NeRF網絡結構,包含位置編碼和多層全連接網絡,輸入是位置和視角,輸出是顏色和密度。
2)實現了位置編碼函數,通過正弦和余弦變換引入高頻信息。
3)實現了體積渲染函數,在光線上采樣點,查詢NeRF網絡預測顏色和密度,然后通過加權平均實現整體渲染。
4)定義了渲染損失函數,計算預測顏色和真實顏色的均方誤差。
當然,本方案只是實現NeRF的一個基礎方案,更多的細節還需要進行優化。
當然,為了方便下載,我們已經將上述兩個源代碼打包好了。
審核編輯:湯梓紅
-
神經網絡
+關注
關注
42文章
4779瀏覽量
101030 -
函數
+關注
關注
3文章
4345瀏覽量
62867 -
代碼
+關注
關注
30文章
4821瀏覽量
68890 -
pytorch
+關注
關注
2文章
808瀏覽量
13322
原文標題:一文帶你入門NeRF:利用PyTorch實現NeRF代碼詳解(附代碼)
文章出處:【微信號:3D視覺工坊,微信公眾號:3D視覺工坊】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論