一、concat函数介绍
在numpy中,concat函数可以用于沿特定轴连接两个或多个数组。
np.concatenate((a1, a2, ...), axis=0, out=None)
参数axis指示了沿哪个轴连接数组。如果没有指定,np.concatenate默认将沿着第一个维度(即axis=0)进行连接。
二、在第一个轴上连接数组
当输入参数里所有数组的shape在第一个轴上的大小相同时,我们可以通过np.concatenate将它们在第一个轴上连接起来。
import numpy as np
arr1 = np.array([[1, 2], [3, 4]])
arr2 = np.array([[5, 6], [7, 8]])
arr3 = np.array([[9, 10], [11, 12]])
result = np.concatenate((arr1, arr2, arr3))
print(result)
输出结果为:
[[ 1 2]
[ 3 4]
[ 5 6]
[ 7 8]
[ 9 10]
[11 12]]
三、在其他轴上连接数组
当我们需要连接的数组shape不同的轴时,可以通过np.concatenate指定轴号,来沿其他轴对数组进行连接。
arr1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
arr2 = np.array([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
# 沿第一个轴连接
result1 = np.concatenate((arr1, arr2), axis=0)
# 沿第三个轴连接
result2 = np.concatenate((arr1, arr2), axis=2)
print("沿第一个轴连接结果:\n", result1)
print("沿第三个轴连接结果:\n", result2)
输出结果为:
沿第一个轴连接结果:
[[[ 1 2]
[ 3 4]]
[[ 5 6]
[ 7 8]]
[[ 9 10]
[11 12]]
[[13 14]
[15 16]]]
沿第三个轴连接结果:
[[[ 1 2 9 10]
[ 3 4 11 12]]
[[ 5 6 13 14]
[ 7 8 15 16]]]
四、使用out参数避免数组复制
在进行大量数组连接操作时,numpy会创建一个新的数组来存储最终结果,这将导致不必要的内存复制。可以通过指定参数out来避免这种情况。
x = np.array([1, 2, 3])
y = np.array([4, 5, 6])
z = np.array([7, 8, 9])
# 指定输出数组
out = np.zeros(9)
np.concatenate([x, y, z], out=out)
print(out)
输出结果为:
[1. 2. 3. 4. 5. 6. 7. 8. 9.]
五、使用stack函数进行堆叠操作
除了concatenate函数,numpy还提供了stack函数,不同之处在于,stack函数会将输入的数组沿新的轴方向堆叠起来。
arr1 = np.array([1, 2, 3])
arr2 = np.array([4, 5, 6])
arr3 = np.array([7, 8, 9])
# 沿新轴(第一轴)方向堆叠数组
result = np.stack((arr1, arr2, arr3))
print(result)
输出结果为:
[[1 2 3]
[4 5 6]
[7 8 9]]
六、总结
Numpy的concatenate函数提供了沿特定轴连接两个或多个数组的功能。通过指定参数axis可以选择不同轴向进行连接。如果需要避免不必要的内存复制,可以使用参数out来指定输出数组。stack函数则提供了将输入的数组沿新的轴方向堆叠起来的功能。