pyspark對應的scala代碼PythonRDD類


pyspark jvm端的scala代碼PythonRDD

代碼版本為 spark 2.2.0

1.PythonRDD.class

這個rdd類型是python能接入spark的關鍵

/**
 * An ordinary RDD implementation (it supplies `compute`, `partitioner` and `getPartitions`)
 * whose elements are the raw byte arrays produced by running a Python function over the
 * parent RDD's partitions.
 *
 * NOTE(review): according to the original annotations, this is the object returned by the
 * `_jrdd` property of PySpark's `PipelinedRDD`, and `parent` is the `_prev_jrdd` handed in
 * from there, i.e. the data source RDD that was built first. Confirm against pyspark/rdd.py.
 *
 * @param parent              upstream RDD; every record the Python code sees comes from it
 * @param func                the user's Python computation
 * @param preservePartitoning when true, the parent's partitioner is reported as this RDD's
 *                            partitioner; otherwise no partitioner is reported
 */
private[spark] class PythonRDD(
    parent: RDD[_],
    func: PythonFunction,
    preservePartitoning: Boolean)
  extends RDD[Array[Byte]](parent) {

  val bufferSize = conf.getInt("spark.buffer.size", 65536)
  val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)

  /** Partitions are exactly those of the parent RDD. */
  override def getPartitions: Array[Partition] = firstParent.partitions

  override val partitioner: Option[Partitioner] =
    if (!preservePartitoning) None else firstParent.partitioner

  val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

  /**
   * Computes one partition by feeding the parent's records for `split` to a `PythonRunner`
   * and returning whatever the runner yields.
   */
  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
    // This PythonRunner is the in-task runner, not the launcher that spark-submit executes.
    val pythonRunner = PythonRunner(func, bufferSize, reuse_worker)
    // Asking the parent for its iterator is what triggers the parent RDD's computation; the
    // resulting records become the runner's input.
    val parentRecords: Iterator[_] = firstParent.iterator(split, context)
    pythonRunner.compute(parentRecords, split.index, context)
  }
}

2.PythonRunner.class

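這個類(org.apache.spark.api.python.PythonRunner)是 RDD 內部執行計算時實際負責計算的類,並不是用 spark-submit 提交代碼時負責啟動 py4j 的那個 PythonRunner(org.apache.spark.deploy.PythonRunner),兩者只是同名。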

/**
 * Executes Python functions for one task by exchanging data with an external Python worker
 * over a socket. This is the runner used inside RDD computation, not the launcher of the same
 * name that is used when an application is submitted.
 *
 * `compute` does three things:
 *  1. obtains a Python worker through `SparkEnv.createPythonWorker`;
 *  2. starts a [[WriterThread]] that writes the task set-up and the input data to the worker;
 *  3. returns an iterator that reads the worker's results back from the socket.
 *
 * The data written by the writer thread is whatever `inputIterator` yields, i.e. the output of
 * the upstream (data source) RDD.
 *
 * NOTE(review): the original annotations state that the worker is started through
 * `pyspark.daemon`; that is not visible in this class -- confirm in the worker factory.
 *
 * @param funcs        chains of Python functions to run; the first function of the first chain
 *                     supplies the environment, executable, version and accumulator
 * @param bufferSize   buffer size used for the buffered streams wrapped around the socket
 * @param reuse_worker when true, `SPARK_REUSE_WORKER=1` is exported and the worker is released
 *                     back to the environment after a clean end of stream
 * @param isUDF        selects which command layout is written to the worker
 * @param argOffsets   argument offsets, one array per entry of `funcs` (written only for UDFs)
 */
private[spark] class PythonRunner(
    funcs: Seq[ChainedPythonFunctions],
    bufferSize: Int,
    reuse_worker: Boolean,
    isUDF: Boolean,
    argOffsets: Array[Array[Int]])
  extends Logging {

  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")

  // Environment and command used to run Python, taken from the first function.
  private val envVars = funcs.head.funcs.head.envVars
  private val pythonExec = funcs.head.funcs.head.pythonExec
  private val pythonVer = funcs.head.funcs.head.pythonVer

  private val accumulator = funcs.head.funcs.head.accumulator

  /**
   * Runs the Python functions over `inputIterator` for one partition.
   *
   * @param inputIterator  records to send to the Python worker
   * @param partitionIndex index of the partition being computed; sent to the worker first
   * @param context        task context, used for completion/interruption handling and metrics
   * @return an interruptible iterator over the byte arrays the worker sends back
   */
  def compute(
      inputIterator: Iterator[_],
      partitionIndex: Int,
      context: TaskContext): Iterator[Array[Byte]] = {
    val startTime = System.currentTimeMillis
    val env = SparkEnv.get
    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
    envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread
    if (reuse_worker) {
      envVars.put("SPARK_REUSE_WORKER", "1")
    }

    // Obtain the Python worker; the returned socket is the channel used to talk to it.
    // NOTE(review): per the original annotations the worker is backed by pyspark.daemon and
    // only one daemon is started -- confirm against SparkEnv.createPythonWorker.
    val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
    @volatile var released = false

    // Writer thread that pushes the input data over the socket to the Python worker.
    val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context)

    // When the task completes, stop the writer thread and close the socket unless the worker
    // has been released for reuse.
    context.addTaskCompletionListener { context =>
      writerThread.shutdownOnTaskCompletion()
      if (!reuse_worker || !released) {
        try {
          worker.close()
        } catch {
          case e: Exception =>
            logWarning("Failed to close worker socket", e)
        }
      }
    }

    writerThread.start()
    new MonitorThread(env, worker, context).start()

    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
    // Iterator that pulls the worker's results off the socket. It always holds one pre-read
    // element in `_nextObj`; `null` there means the data section has ended.
    val stdoutIterator = new Iterator[Array[Byte]] {
      override def next(): Array[Byte] = {
        // Honour the Iterator contract: fail loudly instead of handing out null when exhausted.
        if (!hasNext) {
          throw new NoSuchElementException("next on exhausted Python worker output")
        }
        val obj = _nextObj
        _nextObj = read()
        obj
      }

      /**
       * Reads the next frame from the worker. Returns the payload for data frames, recurses
       * past timing frames, throws for a Python-side exception, and returns `null` once the
       * end-of-data section has been consumed.
       */
      private def read(): Array[Byte] = {
        // Surface a failure of the writer thread before blocking on the socket.
        writerThread.exception.foreach { e => throw e }
        try {
          stream.readInt() match {
            case length if length > 0 =>
              val obj = new Array[Byte](length)
              stream.readFully(obj)
              obj
            case 0 => Array.empty[Byte]
            case SpecialLengths.TIMING_DATA =>
              // Timing data from worker
              val bootTime = stream.readLong()
              val initTime = stream.readLong()
              val finishTime = stream.readLong()
              val boot = bootTime - startTime
              val init = initTime - bootTime
              val finish = finishTime - initTime
              val total = finishTime - startTime
              logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
                init, finish))
              val memoryBytesSpilled = stream.readLong()
              val diskBytesSpilled = stream.readLong()
              context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
              context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
              read()
            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
              // Signals that an exception has been thrown in python
              val exLength = stream.readInt()
              val obj = new Array[Byte](exLength)
              stream.readFully(obj)
              throw new PythonException(new String(obj, StandardCharsets.UTF_8),
                writerThread.exception.orNull)
            case SpecialLengths.END_OF_DATA_SECTION =>
              // We've finished the data section of the output, but we can still
              // read some accumulator updates:
              val numAccumulatorUpdates = stream.readInt()
              (1 to numAccumulatorUpdates).foreach { _ =>
                val updateLen = stream.readInt()
                val update = new Array[Byte](updateLen)
                stream.readFully(update)
                accumulator.add(update)
              }
              // Check whether the worker is ready to be re-used.
              if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
                if (reuse_worker) {
                  env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
                  released = true
                }
              }
              null
          }
        } catch {
          case e: Exception if context.isInterrupted =>
            logDebug("Exception thrown after task interruption", e)
            throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))

          case e: Exception if env.isStopped =>
            logDebug("Exception thrown after context is stopped", e)
            null  // exit silently

          case e: Exception if writerThread.exception.isDefined =>
            logError("Python worker exited unexpectedly (crashed)", e)
            logError("This may have been caused by a prior exception:", writerThread.exception.get)
            throw writerThread.exception.get

          case eof: EOFException =>
            throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
        }
      }

      // Pre-read the first element; must stay below `read()`'s dependencies being initialised.
      var _nextObj = read()

      override def hasNext: Boolean = _nextObj != null
    }
    // Hand back the result-pulling iterator, wrapped so that it honours task interruption.
    new InterruptibleIterator(context, stdoutIterator)
  }

  /**
   * Thread that writes everything the Python worker needs to its socket: the task set-up
   * (partition index, Python version, task context info, files directory, includes and
   * broadcast variables), the serialized Python command, and finally the input records.
   */
  class WriterThread(
      env: SparkEnv,
      worker: Socket,
      inputIterator: Iterator[_],
      partitionIndex: Int,
      context: TaskContext)
    extends Thread(s"stdout writer for $pythonExec") {

    @volatile private var _exception: Exception = null

    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))

    setDaemon(true)

    /** Contains the exception thrown while writing the parent iterator to the Python process. */
    def exception: Option[Exception] = Option(_exception)

    /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
    def shutdownOnTaskCompletion(): Unit = {
      assert(context.isCompleted)
      this.interrupt()
    }

    /**
     * Main logic of the thread: writes the set-up information and broadcast variables, then
     * the Python command, then the input records that have to be computed.
     */
    override def run(): Unit = Utils.logUncaughtExceptions {
      try {
        TaskContext.setTaskContext(context)
        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
        val dataOut = new DataOutputStream(stream)
        // Partition index
        dataOut.writeInt(partitionIndex)
        // Python version of driver
        PythonRDD.writeUTF(pythonVer, dataOut)
        // Write out the TaskContextInfo
        dataOut.writeInt(context.stageId())
        dataOut.writeInt(context.partitionId())
        dataOut.writeInt(context.attemptNumber())
        dataOut.writeLong(context.taskAttemptId())
        // sparkFilesDir
        PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
        // Python includes (*.zip and *.egg files)
        dataOut.writeInt(pythonIncludes.size)
        for (include <- pythonIncludes) {
          PythonRDD.writeUTF(include, dataOut)
        }
        // Broadcast variables
        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
        val newBids = broadcastVars.map(_.id).toSet
        // number of different broadcasts
        val toRemove = oldBids.diff(newBids)
        val cnt = toRemove.size + newBids.diff(oldBids).size
        dataOut.writeInt(cnt)
        for (bid <- toRemove) {
          // remove the broadcast from worker
          dataOut.writeLong(- bid - 1)  // bid >= 0
          oldBids.remove(bid)
        }
        for (broadcast <- broadcastVars) {
          if (!oldBids.contains(broadcast.id)) {
            // send new broadcast
            dataOut.writeLong(broadcast.id)
            PythonRDD.writeUTF(broadcast.value.path, dataOut)
            oldBids.add(broadcast.id)
          }
        }
        dataOut.flush()
        // Serialized command:
        if (isUDF) {
          dataOut.writeInt(1)
          dataOut.writeInt(funcs.length)
          funcs.zip(argOffsets).foreach { case (chained, offsets) =>
            dataOut.writeInt(offsets.length)
            offsets.foreach { offset =>
              dataOut.writeInt(offset)
            }
            dataOut.writeInt(chained.funcs.length)
            chained.funcs.foreach { f =>
              dataOut.writeInt(f.command.length)
              dataOut.write(f.command)
            }
          }
        } else {
          dataOut.writeInt(0)
          val command = funcs.head.funcs.head.command
          dataOut.writeInt(command.length)
          dataOut.write(command)
        }
        // Data values
        PythonRDD.writeIteratorToStream(inputIterator, dataOut)
        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
        dataOut.flush()
      } catch {
        case e: Exception if context.isCompleted || context.isInterrupted =>
          logDebug("Exception thrown after task completion (likely due to cleanup)", e)
          if (!worker.isClosed) {
            Utils.tryLog(worker.shutdownOutput())
          }

        case e: Exception =>
          // We must avoid throwing exceptions here, because the thread uncaught exception handler
          // will kill the whole executor (see org.apache.spark.executor.Executor).
          _exception = e
          if (!worker.isClosed) {
            Utils.tryLog(worker.shutdownOutput())
          }
      }
    }
  }

  /**
   * Thread that watches whether the task is still running and destroys the Python worker if
   * the task is interrupted before it completes.
   */
  class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
    extends Thread(s"Worker Monitor for $pythonExec") {

    setDaemon(true)

    override def run(): Unit = {
      // Kill the worker if it is interrupted, checking until task completion.
      // TODO: This has a race condition if interruption occurs, as completed may still become true.
      while (!context.isInterrupted && !context.isCompleted) {
        Thread.sleep(2000)
      }
      if (!context.isCompleted) {
        try {
          logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
          env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
        } catch {
          case e: Exception =>
            logError("Exception when trying to kill worker", e)
        }
      }
    }
  }
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM