Kaldi nnet3的前向計算


  • 根據任務,構建ComputationRequst
  • 編譯ComputationRequst,獲取NnetComputation

    std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);

    • 創建計算——CreateComputation

      compiler.CreateComputation(opts, computation);

      • 從輸出節點開始逐步向前計算依賴關系

        ComputationGraphBuilder builder(nnet_, &graph_);

        builder.Compute(*(requests_[segment]));

        每次向前深入一層,並計算所有Cindexes的依賴關系

        BuildGraphOneIter();

        對其中的每個Cindex,若需要計算其依賴:

        AddDependencies(cindex_id);

        • 若為kDescriptordesc.GetDependencies(index, &input_cindexes);
        • 若為kComponentcomponent->GetInputIndexes(request_->misc_info, index, &input_indexes);
        • 若為kDimRangeinput_cindexes[0] = Cindex(node.u.node_index, index);
        • 若為kInput,不需要依賴
      • 檢查是否所有的輸出都是可計算的

        if (!builder.AllOutputsAreComputable())

      • 將數據與運算組織為計算步

        對每個chunkCindexes根據不同網絡層切分為phases,並以chunk為單位進行處理

        steps_computer.ComputeForSegment(*(requests_[segment]),phases_per_segment[segment]);

        phases以節點為單位切分為sub-phases,並以sub-phases為單位進行處理

        ProcessSubPhase(request, sub_phases[j]);

        sub-phases對於節點類型為:

        component-nodeProcessComponentStep(sub_phase);

        kSimpleComponent:除索引數-1外,將step復制為input_step

        else:從graph_->dependencies[c]獲取依賴並插入到input_step

        input-nodeProcessInputOrOutputStep(request, false, sub_phase);

        output-nodeProcessInputOrOutputStep(request, true, sub_phase);

        dim-range-nodeProcessDimRangeSubPhase(sub_phase);

    • 優化計算——Optimize

      Optimize(opt_config_, nnet_,

      MaxOutputTimeInRequest(request),

      computation);

  • 根據NnetComputation構建NnetComputer

    NnetComputer computer(opts_.compute_config, *computation,

    nnet_, nnet_to_update);

  • 運行NnetComputer

    computer.Run();

    NnetComputation中所有Command迭代地運行

    ExecuteCommand();

    kPropagatevoid *memo = component->Propagate(indexes, input, &output);

    kBackpropcomponent->Backprop(debug_str.str(), indexes,

    in_value, out_value, out_deriv,

    memo, upd_component,

    c.arg6 == 0 ? NULL : &in_deriv);

    ...

  • NnetComputer獲取輸出

    computer.GetOutputDestructive("output", &cu_output);


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM