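This series is based on spark-2.4.6.
The previous installment covered how the map side of a ShuffleMapTask writes its data in Spark. This installment looks at how the reduce side reads that data.
ShuffleMapTask.runTask contains the following step: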
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
where rdd.iterator is:
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
getOrCompute(split, context)
} else {
computeOrReadCheckpoint(split, context)
}
}
Either branch eventually calls the following RDD method:
def compute(split: Partition, context: TaskContext): Iterator[T]
RDD has many implementations. Take groupBy on an RDD: it returns a ShuffledRDD, whose compute is implemented as follows:
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[(K, C)]]
}
The read called here is implemented in BlockStoreShuffleReader:
override def read(): Iterator[Product2[K, C]] = {
val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
serializerManager.wrapStream,
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
val serializerInstance = dep.serializer.newInstance()
val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map { record =>
readMetrics.incRecordsRead(1)
record
},
context.taskMetrics().mergeShuffleReadMetrics())
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
val resultIter = dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
val sorter = new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
// Use completion callback to stop sorter if task was finished/cancelled.
context.addTaskCompletionListener[Unit](_ => {
sorter.stop()
})
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
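First, note the call mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition): it asks the master which shuffle blocks the current task has to fetch.
The important logic lives in ShuffleBlockFetcherIterator. Two parameters are worth noting:
spark.reducer.maxSizeInFlight (48m by default in the code above)
spark.reducer.maxReqsInFlight (Int.MaxValue by default in the code above)
ShuffleBlockFetcherIterator runs its initialize method as soon as it is constructed: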
private[this] def initialize(): Unit = {
context.addTaskCompletionListener[Unit](_ => cleanup())
val remoteRequests = splitLocalRemoteBlocks()
fetchRequests ++= Utils.randomize(remoteRequests)
fetchUpToMaxBytes()
val numFetches = remoteRequests.size - fetchRequests.size
fetchLocalBlocks()
}
First, splitLocalRemoteBlocks decides which blocks have to be fetched:
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
val remoteRequests = new ArrayBuffer[FetchRequest]
for ((address, blockInfos) <- blocksByAddress) {
if (address.executorId == blockManager.blockManagerId.executorId) {
blockInfos.find(_._2 <= 0) match {
case Some((blockId, size)) if size < 0 =>
throw new BlockException(blockId, "Negative block size " + size)
case Some((blockId, size)) if size == 0 =>
throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
case None => // do nothing.
}
localBlocks ++= blockInfos.map(_._1)
numBlocksToFetch += localBlocks.size
} else {
val iterator = blockInfos.iterator
var curRequestSize = 0L
var curBlocks = new ArrayBuffer[(BlockId, Long)]
while (iterator.hasNext) {
val (blockId, size) = iterator.next()
if (size < 0) {
throw new BlockException(blockId, "Negative block size " + size)
} else if (size == 0) {
throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
} else {
curBlocks += ((blockId, size))
remoteBlocks += blockId
numBlocksToFetch += 1
curRequestSize += size
}
if (curRequestSize >= targetRequestSize ||
curBlocks.size >= maxBlocksInFlightPerAddress) {
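remoteRequests += new FetchRequest(address, curBlocks)
// Start a new batch for the next request to this address
curBlocks = new ArrayBuffer[(BlockId, Long)]
curRequestSize = 0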
}
}
if (curBlocks.nonEmpty) {
remoteRequests += new FetchRequest(address, curBlocks)
}
}
}
remoteRequests
}
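Here the blocks to fetch are split into local and remote ones (data is represented as blocks). For local data, the BlockId is added to the localBlocks collection. For remote data, the code walks node by node through all blocks held by that node. While iterating over one node's blocks, once the accumulated size of the blocks seen so far is >= targetRequestSize, or their count reaches maxBlocksInFlightPerAddress, those blocks are put into a single fetch request. targetRequestSize is the spark.reducer.maxSizeInFlight mentioned earlier divided by 5; the division by 5 is there to increase parallelism. maxBlocksInFlightPerAddress is the maximum number of blocks requested from one node in a single request, and defaults to Int.MaxValue. At this point the local and remote blocks have been separated. Next, fetchUpToMaxBytes takes the pending requests for each node and fetches the block data.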
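The batching rule for one remote node can be sketched in isolation. This is a simplified illustration in Java, not Spark's code: the class and method names are hypothetical, blocks are reduced to their sizes, and the zero and negative size checks are left out.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class FetchRequestBatcher {
    /**
     * Groups the block sizes of ONE remote address into request batches.
     * A batch is closed once its accumulated size reaches maxBytesInFlight / 5
     * or its block count reaches maxBlocksPerRequest.
     */
    static List<List<Long>> batch(List<Long> blockSizes, long maxBytesInFlight, int maxBlocksPerRequest) {
        long targetRequestSize = Math.max(maxBytesInFlight / 5, 1L);
        List<List<Long>> requests = new ArrayList<>();
        List<Long> curBlocks = new ArrayList<>();
        long curRequestSize = 0L;
        for (long size : blockSizes) {
            curBlocks.add(size);
            curRequestSize += size;
            if (curRequestSize >= targetRequestSize || curBlocks.size() >= maxBlocksPerRequest) {
                requests.add(curBlocks);
                // Start a fresh batch; without this reset every later request
                // would carry the earlier blocks again.
                curBlocks = new ArrayList<>();
                curRequestSize = 0L;
            }
        }
        if (!curBlocks.isEmpty()) {
            requests.add(curBlocks); // leftover blocks form the final request
        }
        return requests;
    }

    public static void main(String[] args) {
        // maxBytesInFlight = 500, so targetRequestSize = 100
        System.out.println(batch(Arrays.asList(30L, 30L, 50L, 10L), 500L, Integer.MAX_VALUE));
        // prints [[30, 30, 50], [10]]
    }
}
```

With a smaller per-request block limit, for example 2, the same input would be cut by count before the size target is reached.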
The request is sent by sendRequest, which contains one piece of logic worth noting:
if (req.size > maxReqSizeShuffleToMem) {
shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
blockFetchingListener, this)
} else {
shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
blockFetchingListener, null)
}
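This checks the size of the data to fetch: if it is greater than maxReqSizeShuffleToMem, the data is written to local disk. maxReqSizeShuffleToMem is configured with spark.maxRemoteBlockSizeFetchToMem and defaults to Int.MaxValue - 512 bytes.
Eventually NettyBlockTransferService.fetchBlocks is called: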
override def fetchBlocks(
host: String,
port: Int,
execId: String,
blockIds: Array[String],
listener: BlockFetchingListener,
tempFileManager: DownloadFileManager): Unit = {
try {
val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
val client = clientFactory.createClient(host, port)
new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
transportConf, tempFileManager).start()
}
}
val maxRetries = transportConf.maxIORetries()
if (maxRetries > 0) {
new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
} else {
blockFetchStarter.createAndStart(blockIds, listener)
}
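} catch {
case e: Exception =>
logError("Exception while beginning fetchBlocks", e)
blockIds.foreach(listener.onBlockFetchFailure(_, e))
}
}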
As we can see, a OneForOneBlockFetcher is started at the end:
public void start() {
if (blockIds.length == 0) {
throw new IllegalArgumentException("Zero-sized blockIds array");
}
client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
@Override
public void onSuccess(ByteBuffer response) {
try {
streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
for (int i = 0; i < streamHandle.numChunks; i++) {
if (downloadFileManager != null) {
client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
new DownloadCallback(i));
} else {
client.fetchChunk(streamHandle.streamId, i, chunkCallback);
}
}
} catch (Exception e) {
failRemainingBlocks(blockIds, e);
}
}
@Override
public void onFailure(Throwable e) {
failRemainingBlocks(blockIds, e);
}
});
}
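This first sends an OpenBlocks message to the node that holds the data. If that succeeds, TransportClient is used to fetch the corresponding data. The code checks whether downloadFileManager is null, which reflects the condition described above: if the size of the data to fetch is greater than maxReqSizeShuffleToMem, the data has to be written to a file and downloadFileManager is non-null; otherwise the data goes straight to memory.
- The file path ultimately sends a StreamRequest
- The memory path sends a ChunkFetchRequest
When the node responds successfully, the corresponding callback handles the result: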
public void stream(String streamId, StreamCallback callback) {
StdChannelListener listener = new StdChannelListener(streamId) {
void handleFailure(String errorMsg, Throwable cause) throws Exception {
callback.onFailure(streamId, new IOException(errorMsg, cause));
}
};
synchronized (this) {
handler.addStreamCallback(streamId, callback);
channel.writeAndFlush(new StreamRequest(streamId)).addListener(listener);
}
}
public void fetchChunk(
long streamId,
int chunkIndex,
ChunkReceivedCallback callback) {
StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
StdChannelListener listener = new StdChannelListener(streamChunkId) {
void handleFailure(String errorMsg, Throwable cause) {
handler.removeFetchRequest(streamChunkId);
callback.onFailure(chunkIndex, new IOException(errorMsg, cause));
}
};
handler.addFetchRequest(streamChunkId, callback);
channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(listener);
}
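At this point the reduce side has sent out its requests. Next, let's see how the node holding the data responds to them.
First the OpenBlocks request, which is ultimately handled in NettyBlockRpcServer: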
override def receive(
client: TransportClient,
rpcMessage: ByteBuffer,
responseContext: RpcResponseCallback): Unit = {
val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
logTrace(s"Received request: $message")
message match {
case openBlocks: OpenBlocks =>
val blocksNum = openBlocks.blockIds.length
val blocks = for (i <- (0 until blocksNum).view)
yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava,
client.getChannel)
logTrace(s"Registered streamId $streamId with $blocksNum buffers")
responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)
case uploadBlock: UploadBlock =>
// StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
val (level: StorageLevel, classTag: ClassTag[_]) = {
serializer
.newInstance()
.deserialize(ByteBuffer.wrap(uploadBlock.metadata))
.asInstanceOf[(StorageLevel, ClassTag[_])]
}
val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
val blockId = BlockId(uploadBlock.blockId)
logDebug(s"Receiving replicated block $blockId with level ${level} " +
s"from ${client.getSocketAddress}")
blockManager.putBlockData(blockId, data, level, classTag)
responseContext.onSuccess(ByteBuffer.allocate(0))
}
}
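This registers a streamId and a matching StreamState for each request and returns a StreamHandle to the fetching side, containing the streamId and the number of blocks. Each block to be fetched is resolved through getBlockData: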
override def getBlockData(blockId: BlockId): ManagedBuffer = {
if (blockId.isShuffle) {
shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
} else {
getLocalBytes(blockId) match {
case Some(blockData) =>
new BlockManagerManagedBuffer(blockInfoManager, blockId, blockData, true)
case None =>
reportBlockStatus(blockId, BlockStatus.empty)
throw new BlockNotFoundException(blockId.toString)
}
}
}
Here we are on the reduce read path, so blockId.isShuffle is true:
val shuffleBlockResolver = shuffleManager.shuffleBlockResolver
val buf = new ChunkedByteBuffer( shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())
Some(new ByteBufferBlockData(buf, true))
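The read ends up in IndexShuffleBlockResolver. This is what the previous installment described: the map side writes an index file along with the data, and that index file is used here to locate the requested data: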
override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
val channel = Files.newByteChannel(indexFile.toPath)
channel.position(blockId.reduceId * 8L)
val in = new DataInputStream(Channels.newInputStream(channel))
try {
val offset = in.readLong()
val nextOffset = in.readLong()
val actualPosition = channel.position()
val expectedPosition = blockId.reduceId * 8L + 16
if (actualPosition != expectedPosition) {
....
}
new FileSegmentManagedBuffer(
transportConf,
getDataFile(blockId.shuffleId, blockId.mapId),
offset,
nextOffset - offset)
} finally {
in.close()
}
}
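What is finally returned is a FileSegmentManagedBuffer, and then a StreamHandle is returned to the client. So sending OpenBlocks only makes the data side build FileSegmentManagedBuffer objects, that is, work out which data is to be fetched; nothing else happens.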
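The index lookup can be tried out on its own. The sketch below, in Java, assumes the index file holds numPartitions + 1 cumulative offsets as 8-byte longs starting at 0, which is the layout the two readLong calls above imply. The class and method names are hypothetical.

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.nio.channels.SeekableByteChannel;
import java.nio.file.Files;
import java.nio.file.Path;

class IndexFileLookup {
    /** Writes cumulative offsets: a leading 0, then one running total per partition. */
    static void writeIndex(Path indexFile, long[] partitionLengths) throws IOException {
        try (DataOutputStream out = new DataOutputStream(Files.newOutputStream(indexFile))) {
            long offset = 0L;
            out.writeLong(offset);
            for (long length : partitionLengths) {
                offset += length;
                out.writeLong(offset);
            }
        }
    }

    /** Returns {offset, length} of the data-file segment that belongs to reduceId. */
    static long[] locate(Path indexFile, int reduceId) throws IOException {
        try (SeekableByteChannel channel = Files.newByteChannel(indexFile)) {
            channel.position(reduceId * 8L); // each entry is one 8-byte long
            DataInputStream in = new DataInputStream(Channels.newInputStream(channel));
            long offset = in.readLong();     // start of this partition
            long nextOffset = in.readLong(); // start of the next partition
            return new long[] {offset, nextOffset - offset};
        }
    }

    public static void main(String[] args) throws IOException {
        Path index = Files.createTempFile("shuffle", ".index");
        try {
            writeIndex(index, new long[] {100L, 0L, 250L});
            long[] segment = locate(index, 2);
            System.out.println("offset=" + segment[0] + " length=" + segment[1]);
            // prints offset=100 length=250
        } finally {
            Files.deleteIfExists(index);
        }
    }
}
```

Only two longs are read per lookup, so resolving a block never touches the data file itself.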
The actual data fetch is done by sending a ChunkFetchRequest. Let's see how it is handled.
TransportRequestHandler dispatches these requests:
public void handle(RequestMessage request) {
if (request instanceof ChunkFetchRequest) {
processFetchRequest((ChunkFetchRequest) request);
} else if (request instanceof RpcRequest) {
processRpcRequest((RpcRequest) request);
} else if (request instanceof OneWayMessage) {
processOneWayMessage((OneWayMessage) request);
} else if (request instanceof StreamRequest) {
processStreamRequest((StreamRequest) request);
} else if (request instanceof UploadStream) {
processStreamUpload((UploadStream) request);
} else {
throw new IllegalArgumentException("Unknown request type: " + request);
}
}
Let's look at the ChunkFetchRequest handling first:
private void processFetchRequest(final ChunkFetchRequest req) {
long chunksBeingTransferred = streamManager.chunksBeingTransferred();
if (chunksBeingTransferred >= maxChunksBeingTransferred) {
logger.warn("The number of chunks being transferred {} is above {}, close the connection.",
chunksBeingTransferred, maxChunksBeingTransferred);
channel.close();
return;
}
ManagedBuffer buf;
try {
streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
} catch (Exception e) {
respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e)));
return;
}
streamManager.chunkBeingSent(req.streamChunkId.streamId);
respond(new ChunkFetchSuccess(req.streamChunkId, buf)).addListener(future -> {
streamManager.chunkSent(req.streamChunkId.streamId);
});
}
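This first checks whether the number of chunks currently being transferred has reached maxChunksBeingTransferred; if so, it logs a warning and closes the connection. Otherwise it looks up the ManagedBuffer for the requested chunk, which as we saw above is a FileSegmentManagedBuffer. Looking closely at that buffer, though, it carries no data at all, so how is the data read and sent back? The key is in Netty's encoding. The data side starts its server through NettyBlockTransferService, whose createServer method ends up building a TransportServer. During initialization it calls init, which sets up the Netty pipeline as follows: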
public TransportChannelHandler initializePipeline(
SocketChannel channel,
RpcHandler channelRpcHandler) {
try {
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
channel.pipeline()
.addLast("encoder", ENCODER)
.addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
.addLast("decoder", DECODER)
.addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
.addLast("handler", channelHandler);
return channelHandler;
} catch (RuntimeException e) {
throw e;
}
}
The encoder here is MessageEncoder. Its encode method calls FileSegmentManagedBuffer.convertToNetty:
public Object convertToNetty() throws IOException {
if (conf.lazyFileDescriptor()) {
return new DefaultFileRegion(file, offset, length);
} else {
FileChannel fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ);
return new DefaultFileRegion(fileChannel, offset, length);
}
}
What comes back is a DefaultFileRegion, so the file still has not been turned into a stream. Continuing with MessageEncoder.encode:
public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) throws Exception {
Object body = null;
long bodyLength = 0;
boolean isBodyInFrame = false;
if (in.body() != null) {
try {
bodyLength = in.body().size();
body = in.body().convertToNetty();
isBodyInFrame = in.isBodyInFrame();
} catch (Exception e) {
in.body().release();
if (in instanceof AbstractResponseMessage) {
AbstractResponseMessage resp = (AbstractResponseMessage) in;
// Re-encode this message as a failure response.
String error = e.getMessage() != null ? e.getMessage() : "null";
logger.error(String.format("Error processing %s for client %s",
in, ctx.channel().remoteAddress()), e);
encode(ctx, resp.createFailureResponse(error), out);
} else {
throw e;
}
return;
}
}
Message.Type msgType = in.type();
int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
long frameLength = headerLength + (isBodyInFrame ? bodyLength : 0);
ByteBuf header = ctx.alloc().heapBuffer(headerLength);
header.writeLong(frameLength);
msgType.encode(header);
in.encode(header);
if (body != null) {
out.add(new MessageWithHeader(in.body(), header, body, bodyLength));
} else {
out.add(header);
}
}
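The length arithmetic in encode can be illustrated with a small Java sketch. It is hypothetical and stands outside Spark: it assumes a one-byte type tag and a message header made of an 8-byte stream id plus a 4-byte chunk index, and the tag value used is arbitrary.

```java
import java.nio.ByteBuffer;

class FrameHeaderSketch {
    static final int TYPE_LENGTH = 1;        // assumption: the type tag is one byte
    static final int MESSAGE_LENGTH = 8 + 4; // assumption: stream id (long) + chunk index (int)

    /** Builds the header: frame length, type tag, then the message fields. */
    static ByteBuffer encodeHeader(byte typeTag, long streamId, int chunkIndex,
                                   long bodyLength, boolean bodyInFrame) {
        int headerLength = 8 + TYPE_LENGTH + MESSAGE_LENGTH;
        // The body only counts toward the frame length when it is part of the frame.
        long frameLength = headerLength + (bodyInFrame ? bodyLength : 0);
        ByteBuffer header = ByteBuffer.allocate(headerLength);
        header.putLong(frameLength);
        header.put(typeTag);
        header.putLong(streamId);
        header.putInt(chunkIndex);
        header.flip(); // make the buffer readable from position 0
        return header;
    }

    public static void main(String[] args) {
        ByteBuffer inFrame = encodeHeader((byte) 1, 42L, 3, 1000L, true);
        ByteBuffer notInFrame = encodeHeader((byte) 1, 42L, 3, 1000L, false);
        System.out.println(inFrame.getLong(0));    // prints 1021
        System.out.println(notInFrame.getLong(0)); // prints 21
    }
}
```

The header is small and lives on the heap, while the body (the file region) is attached separately, which is why the two are combined into a MessageWithHeader rather than copied into one buffer.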
The out list ends up holding a MessageWithHeader, which implements Netty's FileRegion interface. When the message is written to the network, FileRegion.transferTo is called; MessageWithHeader implements it as follows:
public long transferTo(final WritableByteChannel target, final long position) throws IOException {
Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position.");
long writtenHeader = 0;
if (header.readableBytes() > 0) {
writtenHeader = copyByteBuf(header, target);
totalBytesTransferred += writtenHeader;
if (header.readableBytes() > 0) {
return writtenHeader;
}
}
long writtenBody = 0;
if (body instanceof FileRegion) {
writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength);
} else if (body instanceof ByteBuf) {
writtenBody = copyByteBuf((ByteBuf) body, target);
}
totalBytesTransferred += writtenBody;
return writtenHeader + writtenBody;
}
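In the end this calls FileRegion.transferTo on the body, which is the DefaultFileRegion created above. Its implementation uses zero-copy to send the file contents to the network, and with that the data transfer is complete.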
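The same idea can be shown with the JDK alone. The Java sketch below sends one segment of a file through FileChannel.transferTo, the JDK primitive for channel-to-channel file transfer. The class name is hypothetical, and the in-memory target is only there to make the example runnable: the JDK can avoid copying through user space when the target is a socket channel, and falls back to an ordinary copy for a target like this one.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.nio.channels.FileChannel;
import java.nio.channels.WritableByteChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

class FileSegmentTransfer {
    /** Sends bytes [offset, offset + length) of the file to target; returns bytes sent. */
    static long transferSegment(Path file, long offset, long length,
                                WritableByteChannel target) throws IOException {
        try (FileChannel channel = FileChannel.open(file, StandardOpenOption.READ)) {
            long transferred = 0L;
            // transferTo may send fewer bytes than requested, so loop until done.
            while (transferred < length) {
                long n = channel.transferTo(offset + transferred, length - transferred, target);
                if (n <= 0) {
                    break; // nothing more to send, e.g. the segment runs past end of file
                }
                transferred += n;
            }
            return transferred;
        }
    }

    public static void main(String[] args) throws IOException {
        Path data = Files.createTempFile("shuffle", ".data");
        try {
            Files.write(data, "AAAABBBBBBCC".getBytes(StandardCharsets.US_ASCII));
            ByteArrayOutputStream sink = new ByteArrayOutputStream();
            long sent = transferSegment(data, 4L, 6L, Channels.newChannel(sink));
            System.out.println(sent + " bytes: " + sink.toString("US-ASCII"));
            // prints 6 bytes: BBBBBB
        } finally {
            Files.deleteIfExists(data);
        }
    }
}
```

The offset and length play the same role as the fields of FileSegmentManagedBuffer: they describe the segment without reading it until the transfer happens.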
Now back to the fetching side. The data side returned a ChunkFetchSuccess, which TransportResponseHandler handles on the fetching side:
if (message instanceof ChunkFetchSuccess) {
ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
if (listener == null) {
logger.warn("Ignoring response for block {} from {} since it is not outstanding",
resp.streamChunkId, getRemoteAddress(channel));
resp.body().release();
} else {
outstandingFetches.remove(resp.streamChunkId);
listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
resp.body().release();
}
}
The listener here is the ChunkCallback we passed in earlier; its onSuccess method is:
public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
}
This listener is the one we passed in earlier, implemented as follows:
val blockFetchingListener = new BlockFetchingListener {
override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
ShuffleBlockFetcherIterator.this.synchronized {
if (!isZombie) {
buf.retain()
remainingBlocks -= blockId
results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
remainingBlocks.isEmpty))
}
}
}
override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
results.put(new FailureFetchResult(BlockId(blockId), address, e))
}
}
As we can see, nothing special happens here: the returned buffer is wrapped in a SuccessFetchResult and put into results, where private[this] val results = new LinkedBlockingQueue[FetchResult] is just a LinkedBlockingQueue used as a buffer.
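So far we have followed how requests are sent and how responses are queued, but not where the fetched data is actually consumed. When does that happen?
The fetcher is a ShuffleBlockFetcherIterator. The first requests go out from fetchUpToMaxBytes in initialize; the results are consumed, and further requests issued, in its next method: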
override def next(): (BlockId, InputStream) = {
if (!hasNext) {
throw new NoSuchElementException
}
numBlocksProcessed += 1
var result: FetchResult = null
var input: InputStream = null
while (result == null) {
val startFetchWait = System.currentTimeMillis()
result = results.take()
val stopFetchWait = System.currentTimeMillis()
shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
result match {
case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
if (address != blockManager.blockManagerId) {
numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
shuffleMetrics.incRemoteBytesRead(buf.size)
if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
}
shuffleMetrics.incRemoteBlocksFetched(1)
}
if (!localBlocks.contains(blockId)) {
bytesInFlight -= size
}
if (isNetworkReqDone) {
reqsInFlight -= 1
logDebug("Number of requests in flight " + reqsInFlight)
}
if (buf.size == 0) {
val msg = s"Received a zero-size buffer for block $blockId from $address"
throwFetchFailedException(blockId, address, new IOException(msg))
}
val in = try {
buf.createInputStream()
} catch {
case e: IOException =>
assert(buf.isInstanceOf[FileSegmentManagedBuffer])
logError("Failed to create input stream from local block", e)
buf.release()
throwFetchFailedException(blockId, address, e)
}
var isStreamCopied: Boolean = false
try {
input = streamWrapper(blockId, in)
if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
isStreamCopied = true
val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
Utils.copyStream(input, out, closeStreams = true)
input = out.toChunkedByteBuffer.toInputStream(dispose = true)
}
} catch {
case e: IOException =>
buf.release()
if (buf.isInstanceOf[FileSegmentManagedBuffer]
|| corruptedBlocks.contains(blockId)) {
throwFetchFailedException(blockId, address, e)
} else {
corruptedBlocks += blockId
fetchRequests += FetchRequest(address, Array((blockId, size)))
result = null
}
} finally {
if (isStreamCopied) {
in.close()
}
}
case FailureFetchResult(blockId, address, e) =>
throwFetchFailedException(blockId, address, e)
}
fetchUpToMaxBytes()
}
currentResult = result.asInstanceOf[SuccessFetchResult]
(currentResult.blockId, new BufferReleasingInputStream(input, this))
}
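When corruption detection applies (detectCorrupt is set, the stream was wrapped, and the block is smaller than maxBytesInFlight / 3), the received data is copied into a ChunkedByteBufferOutputStream, which is then turned back into an input stream. Either way, an input stream is returned to the caller.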
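The hand-off between the fetch callback and next can be reduced to a blocking queue shared by two threads. The Java sketch below is a simplified illustration with hypothetical names: blocks are plain strings, and there is no failure handling, flow control or metrics.

```java
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

class QueueBackedIterator implements Iterator<String> {
    private final BlockingQueue<String> results = new LinkedBlockingQueue<>();
    private final int numBlocksToFetch;
    private int numBlocksProcessed = 0;

    QueueBackedIterator(int numBlocksToFetch) {
        this.numBlocksToFetch = numBlocksToFetch;
    }

    /** Called from the network thread whenever a block has arrived. */
    void onBlockFetchSuccess(String block) {
        results.add(block);
    }

    @Override
    public boolean hasNext() {
        return numBlocksProcessed < numBlocksToFetch;
    }

    /** Blocks the calling (task) thread until the next result is available. */
    @Override
    public String next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        numBlocksProcessed++;
        try {
            return results.take();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for a block", e);
        }
    }

    public static void main(String[] args) throws InterruptedException {
        QueueBackedIterator it = new QueueBackedIterator(3);
        Thread network = new Thread(() -> {
            for (int i = 0; i < 3; i++) {
                it.onBlockFetchSuccess("shuffle_0_" + i + "_0");
            }
        });
        network.start();
        while (it.hasNext()) {
            System.out.println("got " + it.next());
        }
        network.join();
    }
}
```

Because hasNext only counts blocks, the consumer never needs to know whether a result has already arrived; take simply waits for it.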
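The caller, as analyzed earlier, is BlockStoreShuffleReader. Its read method iterates over this data and runs the aggregation operator; inside Aggregator the records are inserted into a map: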
def combineValuesByKey(
iter: Iterator[_ <: Product2[K, V]],
context: TaskContext): Iterator[(K, C)] = {
val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
combiners.insertAll(iter)
updateMetrics(context, combiners)
combiners.iterator
}
def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
var curEntry: Product2[K, V] = null
val update: (Boolean, C) => C = (hadVal, oldVal) => {
if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2)
}
while (entries.hasNext) {
curEntry = entries.next()
val estimatedSize = currentMap.estimateSize()
if (estimatedSize > _peakMemoryUsedBytes) {
_peakMemoryUsedBytes = estimatedSize
}
if (maybeSpill(currentMap, estimatedSize)) {
currentMap = new SizeTrackingAppendOnlyMap[K, C]
}
currentMap.changeValue(curEntry._1, update)
addElementsRead()
}
}
override def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
val newValue = super.changeValue(key, updateFunc)
super.afterUpdate()
newValue
}
def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
assert(!destroyed, destructionMessage)
val k = key.asInstanceOf[AnyRef]
if (k.eq(null)) {
if (!haveNullValue) {
incrementSize()
}
nullValue = updateFunc(haveNullValue, nullValue)
haveNullValue = true
return nullValue
}
var pos = rehash(k.hashCode) & mask
var i = 1
while (true) {
val curKey = data(2 * pos)
if (curKey.eq(null)) {
val newValue = updateFunc(false, null.asInstanceOf[V])
data(2 * pos) = k
data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
incrementSize()
return newValue
} else if (k.eq(curKey) || k.equals(curKey)) {
val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
return newValue
} else {
val delta = i
pos = (pos + delta) & mask
i += 1
}
}
null.asInstanceOf[V] // Never reached but needed to keep compiler happy
}
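When a record is inserted into the map, it is aggregated according to the functions that were passed in: createCombiner for a new key, mergeValue for a key that already exists.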
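The probing and update logic of changeValue can be sketched as a small open-addressing map. This Java version is a simplified illustration under stated assumptions: it uses the key's hashCode directly instead of a rehash step, rejects null keys, and grows with a simple load-factor rule that the excerpt does not show. All names are hypothetical.

```java
import java.util.function.BiFunction;

class AppendOnlyMapSketch<K, V> {
    private int capacity = 8;                         // always a power of two
    private Object[] data = new Object[2 * capacity]; // key at 2*pos, value at 2*pos+1
    private int size = 0;

    /** updateFunc receives (hadValue, oldValue) and returns the value to store. */
    @SuppressWarnings("unchecked")
    V changeValue(K key, BiFunction<Boolean, V, V> updateFunc) {
        if (key == null) {
            throw new IllegalArgumentException("null keys are not handled in this sketch");
        }
        int mask = capacity - 1;
        int pos = key.hashCode() & mask;
        int step = 1;
        while (true) {
            Object curKey = data[2 * pos];
            if (curKey == null) {                     // empty slot: first value for this key
                V newValue = updateFunc.apply(false, null);
                data[2 * pos] = key;
                data[2 * pos + 1] = newValue;
                size++;
                if (size > capacity * 0.7) {
                    grow();
                }
                return newValue;
            } else if (curKey.equals(key)) {          // existing key: merge into the old value
                V newValue = updateFunc.apply(true, (V) data[2 * pos + 1]);
                data[2 * pos + 1] = newValue;
                return newValue;
            } else {                                  // collision: probe with steps 1, 2, 3, ...
                pos = (pos + step) & mask;
                step++;
            }
        }
    }

    /** Doubles the table and re-inserts every entry with the same probe sequence. */
    private void grow() {
        Object[] old = data;
        capacity *= 2;
        data = new Object[2 * capacity];
        int mask = capacity - 1;
        for (int i = 0; i < old.length; i += 2) {
            if (old[i] == null) {
                continue;
            }
            int pos = old[i].hashCode() & mask;
            int step = 1;
            while (data[2 * pos] != null) {
                pos = (pos + step) & mask;
                step++;
            }
            data[2 * pos] = old[i];
            data[2 * pos + 1] = old[i + 1];
        }
    }

    /** Sum aggregation: createCombiner is the identity, mergeValue is addition. */
    static int addCount(AppendOnlyMapSketch<String, Integer> map, String key, int value) {
        return map.changeValue(key, (hadVal, old) -> hadVal ? old + value : value);
    }

    public static void main(String[] args) {
        AppendOnlyMapSketch<String, Integer> counts = new AppendOnlyMapSketch<>();
        for (String word : new String[] {"a", "b", "a", "c", "a"}) {
            System.out.println(word + " -> " + addCount(counts, word, 1));
        }
    }
}
```

Keeping keys and values in one flat array, as the excerpt does, avoids allocating an entry object per key.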
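To wrap up, here is a brief summary:
- When an RDD hits an operator such as repartition, the shuffle data is read through BlockStoreShuffleReader
- BlockStoreShuffleReader first asks the master which data the current node needs to fetch
- It then fetches the data through ShuffleBlockFetcherIterator
- ShuffleBlockFetcherIterator first separates data on other nodes from data on the local node; local data is read directly, while data on other nodes has to travel over the network
- To get data from other nodes, ShuffleBlockFetcherIterator sends a FetchRequest (served by NettyBlockTransferService). Before any chunk is requested, an OpenBlocks request is sent first (through OneForOneBlockFetcher), and its response describes the data to be fetched
- When the data side receives the OpenBlocks request, it uses the block information in the request to find the matching index files, reads the offsets of the requested data from them, and builds a collection of FileSegmentManagedBuffer. It wraps the result in a StreamHandle and returns it to the client; the StreamHandle effectively carries the session information for this transfer
- After receiving that response, the fetching side starts sending ChunkFetchRequest (or StreamRequest) messages to the data side
- When the data side receives one, it uses the stream id from the StreamHandle to find the FileSegmentManagedBuffer collection built for the OpenBlocks request and returns the buffer to the client. Note that the FileSegmentManagedBuffer goes through MessageEncoder, which ends up turning it into a Netty file transfer
- Once the fetching side has the data, it puts it into a map according to the operator
- Done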
Please credit the source when reposting. Link to this article: https://www.uj5u.com/qita/342044.html
