嗨,我有一批影像,我需要將其劃分為不重疊的補丁,并通過 softmax 函式發送每個補丁,然后重建原始影像。我可以按如下方式制作補丁:
@tf.function
def grid_img(img,patch_size=(256, 256), padding="VALID"):
p_height, p_width = patch_size
batch_size, height, width, n_filters = img.shape
p = tf.image.extract_patches(images=img,
sizes=[1,p_height, p_width, 1],
strides=[1,p_height, p_width, 1],
rates=[1, 1, 1, 1],
padding=padding)
new_shape = list(p.shape[1:-1]) [p_height, p_width, n_filters]
p = tf.keras.layers.Reshape(new_shape)(p)
return p
但是我不知道如何批量重建原始影像。對原始批次的簡單重塑不起作用。資料不會以正確的方式排列。我將不勝感激任何幫助。謝謝
uj5u.com熱心網友回復:
IIUC,您應該能夠簡單地使用tf.reshape批量補丁重建原始影像:
import tensorflow as tf
samples = 5
images = tf.random.normal((samples, 256, 256, 3))
@tf.function
def grid(images):
img_shape = tf.shape(images)
batch_size, height, width, n_filters = img_shape[0], img_shape[1], img_shape[2], img_shape[3]
patches = tf.image.extract_patches(images=images,
sizes=[1, 32, 32, 1],
strides=[1, 32, 32, 1],
rates=[1, 1, 1, 1],
padding='VALID')
return tf.reshape(tf.nn.softmax(patches), (batch_size, height, width, n_filters))
patches = grid(images)
print(patches.shape)
# (5, 256, 256, 3)
更新 1:如果您想以正確的順序重建影像,您可以計算梯度,tf.image.extract_patches如此代碼
uj5u.com熱心網友回復:
我想到的一個骯臟的作業是在轉換后跟蹤細胞的位置。不像@alonetogether 回答那么優雅,但仍然可能有助于分享。
import numpy as np
import tensorflow as tf
@tf.function
def grid(images, grid_size=(32, 32)):
grid_height, grid_width = grid_size
patches = tf.image.extract_patches(images=images,
sizes=[1, grid_height, grid_width, 1],
strides=[1, grid_height, grid_width, 1],
rates=[1, 1, 1, 1],
padding='VALID')
return patches
batch_size, height, width, n_filters = shape = (5, 256, 256, 1)
indices = tf.range(batch_size * height * width * n_filters)
images = tf.reshape(indices, (batch_size, height, width, n_filters ))
patches = grid(images)
transfered_indices = tf.reshape(patches, shape=[-1])
tracked_indices = tf.argsort(transfered_indices) # Indices after transformation, Save this
images = tf.random.normal(shape)
patches = grid(images)
flatten_patches = tf.reshape(patches, shape=[-1])
reconstructed = tf.reshape(tf.gather(flatten_patches, tracked_indices), shape)
np.alltrue(reconstructed==images) # True
轉載請註明出處,本文鏈接:https://www.uj5u.com/yidong/456245.html
