PointNet code


Introduction

Components

1. PointNet classification network

2. Part segmentation network

Datasets

1. Point clouds sampled from 3D shapes
2. ShapeNetPart dataset

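Structure

The code is organized into three main parts:

  • Data processing
  • Model construction
  • Result selection

Data processing

Point clouds are converted into a format the program can use. The implementation is in provider.py, which mainly covers data download, preprocessing (shuffle -> rotate -> jitter), and format conversion (hdf5 -> txt).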

shuffle

def shuffle_data(data, labels):
    """ Shuffle data and labels.
        Input:
          data: B,N,... numpy array
          label: B,... numpy array
        Return:
          shuffled data, label and shuffle indices
    """
    idx = np.arange(len(labels))  # index array [0, 1, ..., B-1] (a NumPy array, not a list)
    # print('idx=',idx)  # idx= [   0    1    2 ... 2045 2046 2047]
    np.random.shuffle(idx)  # shuffle the indices in place
    # print('idx=', idx)
    return data[idx, ...], labels[idx], idx
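
A short usage sketch, assuming NumPy; the function is repeated so the block runs on its own, and the toy arrays are made up for illustration. It checks that each point cloud still carries its own label after shuffling.

```python
import numpy as np

def shuffle_data(data, labels):
    idx = np.arange(len(labels))
    np.random.shuffle(idx)
    return data[idx, ...], labels[idx], idx

# Toy batch (illustrative only): 4 clouds of 2 points; every value in cloud i equals i
data = np.arange(4, dtype=np.float32).reshape(4, 1, 1) * np.ones((4, 2, 3), dtype=np.float32)
labels = np.arange(4)

shuffled_data, shuffled_labels, idx = shuffle_data(data, labels)

# Because data and labels are indexed with the same idx, they stay aligned
aligned = bool(np.all(shuffled_data[:, 0, 0] == shuffled_labels))
print(idx, aligned)
```

Returning idx as well makes it possible to apply the same permutation to any other per-sample array.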

rotate

def rotate_point_cloud(batch_data):
    """ Randomly rotate each point cloud in the batch about the y axis.
        Input:
          batch_data: B,N,3 numpy array, e.g. (32, 1024, 3)
        Return:
          B,N,3 numpy array of rotated point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        rotation_angle = np.random.uniform() * 2 * np.pi  # random angle in [0, 2*pi)
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array([[cosval, 0, sinval],
                                    [0, 1, 0],
                                    [-sinval, 0, cosval]])
        shape_pc = batch_data[k, ...]
        # Reshape shape_pc to (?, 3) so it can be multiplied by the (3, 3) rotation matrix
        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
    return rotated_data
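
As a sanity check, the sketch below (NumPy only, with the function repeated in compact form and a random toy batch as input) verifies two properties that follow from the matrix above: the y coordinate of every point is unchanged, and each point keeps its distance from the origin.

```python
import numpy as np

def rotate_point_cloud(batch_data):
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        angle = np.random.uniform() * 2 * np.pi
        c, s = np.cos(angle), np.sin(angle)
        rotation_matrix = np.array([[c, 0, s],
                                    [0, 1, 0],
                                    [-s, 0, c]])
        rotated_data[k, ...] = np.dot(batch_data[k, ...].reshape((-1, 3)), rotation_matrix)
    return rotated_data

# Toy batch (illustrative only): 2 clouds of 16 random points
rng = np.random.default_rng(0)
batch = rng.standard_normal((2, 16, 3)).astype(np.float32)
rotated = rotate_point_cloud(batch)

# The middle row/column of the matrix is (0, 1, 0), so y is untouched
y_unchanged = np.allclose(rotated[..., 1], batch[..., 1], atol=1e-5)
# A rotation matrix is orthogonal, so point norms are preserved
norms_preserved = np.allclose(np.linalg.norm(rotated, axis=-1),
                              np.linalg.norm(batch, axis=-1), atol=1e-5)
print(y_unchanged, norms_preserved)
```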

jitter

def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    B, N, C = batch_data.shape
    assert(clip > 0)
    # Gaussian noise with standard deviation sigma, clipped to [-clip, clip]
    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
    jittered_data += batch_data
    return jittered_data
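
A brief usage sketch, assuming NumPy and an all-zero toy batch chosen so the added noise is easy to inspect. It confirms that no coordinate moves by more than clip.

```python
import numpy as np

def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    B, N, C = batch_data.shape
    assert clip > 0
    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
    jittered_data += batch_data
    return jittered_data

# Toy batch (illustrative only): all points at the origin
batch = np.zeros((2, 1024, 3), dtype=np.float32)
jittered = jitter_point_cloud(batch, sigma=0.01, clip=0.05)

# Largest per-coordinate displacement introduced by the jitter
max_offset = np.abs(jittered - batch).max()
print(jittered.shape, max_offset)
```

With sigma=0.01 and clip=0.05 the clipping only affects samples more than five standard deviations out, so it acts as a safety bound rather than changing the noise noticeably.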

Model construction

Input transform net (T-Net)

with tf.variable_scope('transform_net1') as sc:#T-net
    transform = input_transform_net(point_cloud, is_training, bn_decay, K=3)
print('point cloud=',point_cloud)#(32, 1024, 3)
# print('input transform=',transform)#(32, 3, 3)
point_cloud_transformed = tf.matmul(point_cloud, transform)
# print('point_cloud_transformed=',point_cloud_transformed)#(32, 1024, 3)
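
To illustrate what the batched tf.matmul does with these shapes, here is a NumPy sketch. The transform matrices are random placeholders rather than the output of a trained T-Net, and the batch size is reduced for brevity.

```python
import numpy as np

B, N = 4, 16
rng = np.random.default_rng(0)
point_cloud = rng.standard_normal((B, N, 3))
# Placeholder for the T-Net output: one 3x3 matrix per cloud (random, not learned)
transform = rng.standard_normal((B, 3, 3))

# Batched matrix product: cloud b is multiplied by its own matrix transform[b]
point_cloud_transformed = np.matmul(point_cloud, transform)

# Same result for the first cloud, computed without batching
first_cloud_manual = point_cloud[0] @ transform[0]
print(point_cloud_transformed.shape)
```

Each cloud gets its own matrix, so the output keeps the (B, N, 3) shape of the input.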

mlp(64,128,1024)

net = tf_util.conv2d(net_transformed, 64, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv3', bn_decay=bn_decay)
print('net3=',net)#(32, 1024, 1, 64)
net = tf_util.conv2d(net, 128, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv4', bn_decay=bn_decay)
print('net4=',net)#(32, 1024, 1, 128)
net = tf_util.conv2d(net, 1024, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv5', bn_decay=bn_decay)
print('net5=',net)#(32, 1024, 1, 1024)
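
A convolution with a [1,1] kernel applies the same linear map to every point independently, which is why this stack behaves like a shared MLP. The NumPy sketch below illustrates the shape progression under simplifying assumptions: shared_layer is a hypothetical helper, weights are random, batch norm is omitted, and ReLU is assumed as the activation.

```python
import numpy as np

def shared_layer(x, out_channels, rng):
    """Hypothetical per-point linear layer + ReLU, mimicking a [1,1] convolution.

    x: (B, N, 1, C_in). Weights are random placeholders; batch norm is omitted.
    """
    in_channels = x.shape[-1]
    weights = rng.standard_normal((in_channels, out_channels)) * 0.1
    bias = np.zeros(out_channels)
    # matmul broadcasts the (C_in, C_out) weights over B, N and the singleton axis
    return np.maximum(np.matmul(x, weights) + bias, 0.0)

rng = np.random.default_rng(0)
B, N = 2, 8  # reduced from (32, 1024) to keep the example small
net_transformed = rng.standard_normal((B, N, 1, 64))

net = net_transformed
shapes = []
for width in (64, 128, 1024):
    net = shared_layer(net, width, rng)
    shapes.append(net.shape)
print(shapes)
```

Only the last axis changes from layer to layer; the point axis N is never mixed, so each point is processed in isolation.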

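Class voting

Implementation

batch_pred_sum.shape=(?,40) # score of each sample for each of the 40 classes

pred_val.shape=(?,) # most likely class for each sample

 pred_val = np.argmax(batch_pred_sum, 1)
 # Returns the index of the maximum along the given axis, i.e. the idx (label) of the class with the highest predicted score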
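
The excerpt does not show how batch_pred_sum is accumulated. The sketch below assumes that voting means summing the per-class scores from several passes over the same samples and then taking the argmax; the scores are made-up numbers for 2 samples and 3 classes.

```python
import numpy as np

# Hypothetical class scores from 3 voting passes over 2 samples and 3 classes
vote_scores = np.array([
    [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]],
    [[0.2, 0.7, 0.1], [0.1, 0.6, 0.3]],
    [[0.5, 0.4, 0.1], [0.3, 0.3, 0.4]],
])

# Accumulate the scores of every pass (assumed voting scheme)
batch_pred_sum = np.zeros(vote_scores.shape[1:])
for scores in vote_scores:
    batch_pred_sum += scores

# Pick the class with the highest accumulated score for each sample
pred_val = np.argmax(batch_pred_sum, 1)
print(batch_pred_sum, pred_val)
```

For the first sample, two of the three passes individually favor class 0, yet the summed scores favor class 1. Summing scores is therefore not the same as a majority vote over per-pass decisions.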

Evaluation

Output (predicted label, true label)

</dump/pred_label.txt>

4, 4    
0, 0
2, 2
8, 8
14, 23
...
<shape_names.txt>

airplane
bathtub
bed
bench
bookshelf
bottle
bowl
car
chair
cone
cup
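
The two files can be combined to make the predictions readable. The sketch below uses only the lines and names visible above, hard-coded rather than read from disk; name_of is a hypothetical helper that falls back to a generic name for indices beyond the excerpt.

```python
# Lines as they appear in the pred_label.txt excerpt: "predicted, true"
lines = ["4, 4", "0, 0", "2, 2", "8, 8", "14, 23"]
# Names visible in the shape_names.txt excerpt (the full file has more entries)
shape_names = ["airplane", "bathtub", "bed", "bench", "bookshelf", "bottle",
               "bowl", "car", "chair", "cone", "cup"]

def name_of(index):
    """Hypothetical helper: class name, or a placeholder if outside the excerpt."""
    return shape_names[index] if index < len(shape_names) else "class_%d" % index

# Parse each line into a (predicted, true) pair of integers
pairs = [tuple(int(v) for v in line.split(",")) for line in lines]

# Collect misclassified samples as (predicted name, true name)
errors = [(name_of(p), name_of(t)) for p, t in pairs if p != t]
accuracy = sum(p == t for p, t in pairs) / len(pairs)
print(errors, accuracy)
```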

 

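Save and visualize misclassified examples

</dump/xxxx_pred_name.jpg>
File name = index of the error + true label + predicted label

Example: /dump/1028_label_bed_pred_sofa.jpg

 

The saved image contains three views of the current point cloud, each rotated to a different angle.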

save code

for i in range(start_idx, end_idx):
    l = current_label[i]
    total_seen_class[l] += 1
    total_correct_class[l] += (pred_val[i-start_idx] == l)
    fout.write('%d, %d\n' % (pred_val[i-start_idx], l))
    if pred_val[i-start_idx] != l and FLAGS.visu:  # error case: dump the misclassified cloud
        # File name: index of the error + true label + predicted label
        img_filename = '%d_label_%s_pred_%s.jpg' % (error_cnt, SHAPE_NAMES[l],
                                               SHAPE_NAMES[pred_val[i-start_idx]])
        img_filename = os.path.join(DUMP_DIR, img_filename)
        output_img = pc_util.point_cloud_three_views(np.squeeze(current_data[i, :, :]))
        # Note: scipy.misc.imsave was removed in SciPy 1.2; newer versions need e.g. imageio.imwrite
        scipy.misc.imsave(img_filename, output_img)
        error_cnt += 1

Code for drawing the point cloud

draw_point_cloud()
Input:
points: Nx3 numpy array
Output:
gray image

Logging loss and prediction accuracy

/dump/log_evaluate.txt

eval mean loss: 1.816358
eval accuracy: 0.501216
eval avg class acc: 0.421297
  airplane: 0.980
   bathtub: 0.440
       bed: 0.940
     bench: 0.450
     ...
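
The log reports both an overall accuracy and an average class accuracy, and the two can differ when classes have different numbers of samples. The sketch below shows one way to compute both from per-class counters like total_seen_class and total_correct_class. The counts are hypothetical, chosen only so the per-class ratios match the four values printed above.

```python
import numpy as np

# Hypothetical per-class counts for four classes (not taken from the log)
total_seen_class = np.array([100, 50, 100, 20])
total_correct_class = np.array([98, 22, 94, 9])

# Accuracy of each class separately: 0.98, 0.44, 0.94, 0.45
per_class_acc = total_correct_class / total_seen_class.astype(np.float32)

# Overall accuracy weights every sample equally
overall_acc = total_correct_class.sum() / float(total_seen_class.sum())
# Average class accuracy weights every class equally
avg_class_acc = per_class_acc.mean()
print(per_class_acc, overall_acc, avg_class_acc)
```

Here the small, poorly classified classes pull the class average below the overall accuracy, the same direction as the gap between the two figures in the log.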

 

