# 为mpi4py编写的封装(wrapper):包括并行写入HDF5文件、对numpy array进行split并scatter、bcast和gather,功能已基本完成。如有新想法会持续更新,加入新功能。
import h5py as h5
from mpi4py import MPI
import time
import numpy as np
# Module-wide MPI handles: the world communicator, the number of processes
# and this process's rank. The helpers below use size/rank as defaults.
mpi_comm = MPI.COMM_WORLD
mpi_size, mpi_rank = mpi_comm.Get_size(), mpi_comm.Get_rank()
def process_size(total_size, rank=mpi_rank, size=mpi_size):
    '''
    Return how many of ``total_size`` items belong to ``rank`` when they are
    divided among ``size`` processes.

    The first ``total_size % size`` ranks each take one item more than the
    others.
    '''
    base, extra = divmod(total_size, size)
    return int(base + 1) if rank < int(extra) else int(base)
def ind_end(total_size, rank=mpi_rank, size=mpi_size):
    '''
    Return the exclusive end index of ``rank``'s chunk when ``total_size``
    items are split over ``size`` processes.

    The layout matches ``process_size``: the first ``total_size % size``
    ranks get one extra item.

    Computed in closed form. The first ``rank + 1`` chunks hold
    ``(rank + 1) * base`` items, plus one extra item for each of them that is
    among the first ``extra`` ranks. This works for any ``total_size``,
    including ``total_size < size``, and needs O(1) memory.
    '''
    base = int(total_size // size)
    extra = int(total_size % size)
    return (rank + 1) * base + min(rank + 1, extra)
def ind_start(total_size, rank=mpi_rank, size=mpi_size):
    '''
    Return the inclusive start index of ``rank``'s chunk, i.e. its end index
    (``ind_end``) minus its length (``process_size``).
    '''
    end = ind_end(total_size, rank=rank, size=size)
    length = process_size(total_size, rank=rank, size=size)
    return end - length
def paralle_save_dataset(filename, key, data, axis=0):
    '''
    Write every rank's piece of an array into one HDF5 dataset, concatenated
    along ``axis`` in rank order. Must be called by all ranks.

    Rank 0 first creates dataset ``key`` in ``filename`` (opened in 'a' mode)
    with the full shape and rank 0's dtype. The ranks then write their slices
    one at a time, with a barrier after each turn. Each write is retried up to
    10 times, sleeping 0.5 s after an ``IOError`` (e.g. the file is still
    locked by the previous writer).

    Parameters
    ----------
    filename : str
        Path of the HDF5 file.
    key : str
        Name of the dataset to create.
    data : array_like
        This rank's piece. Pieces must agree in every dimension except
        ``axis``; their lengths along ``axis`` may differ between ranks.
    axis : int
        Axis along which the pieces are concatenated.

    Raises
    ------
    IOError
        On a rank whose write still fails after 10 attempts.
    '''
    data = np.asarray(data)
    shp = list(data.shape)
    # Offsets are taken from the actual per-rank lengths, so any split works,
    # not only the process_size / np.array_split layout.
    counts = mpi_comm.allgather(shp[axis])
    len_axis = sum(counts)
    ist = sum(counts[:mpi_rank])
    ied = ist + shp[axis]
    save_slice = [slice(None, None, None)] * len(shp)
    save_slice[axis] = slice(ist, ied, None)
    save_slice = tuple(save_slice)
    shp[axis] = len_axis
    # NOTE(review): if this raises (e.g. ``key`` already exists), the other
    # ranks block at the barrier below.
    if mpi_rank == 0:
        with h5.File(filename, 'a') as filein:
            filein.create_dataset(key, shape=shp, dtype=data.dtype)
    # Serialize file access: in round ``ii`` only rank ``ii`` opens the file.
    for ii in range(mpi_size):
        if ii == mpi_rank:
            for _ in range(10):
                try:
                    with h5.File(filename, 'a') as filein:
                        filein[key][save_slice] = data
                except IOError as e:
                    print('%s for rank %d, sleep 0.5 second!'%(e, mpi_rank))
                    time.sleep(0.5)
                else:
                    print("Rank %d save dataset '%s' %d to %d into %s!"%(mpi_rank, key, ist, ied, filename))
                    time.sleep(0.5)
                    break
            else:
                # All 10 attempts failed.
                raise IOError("Rank %d cannot save dataset '%s' %d to %d into %s!"%(mpi_rank, key, ist, ied, filename))
        mpi_comm.barrier()
def paralle_save_multi_dataset(filename, key, data):
    '''
    Let every rank write its own ``data`` as dataset ``key`` in ``filename``
    (opened in 'a' mode). Must be called by all ranks.

    Ranks take turns, one per round with a barrier after each round, so the
    file is never opened by two ranks at once. Each write is retried up to
    10 times, sleeping 0.5 s after an ``IOError``.

    NOTE(review): ``key`` presumably differs per rank (the original demo
    used the rank number) - verify with callers.

    Raises
    ------
    IOError
        On a rank whose write still fails after 10 attempts.
    '''
    for ii in range(mpi_size):
        if ii == mpi_rank:
            for _ in range(10):
                try:
                    with h5.File(filename, 'a') as filein:
                        filein[key] = data
                except IOError as e:
                    print('%s for rank %d, sleep 0.5 second!'%(e, mpi_rank))
                    time.sleep(0.5)
                else:
                    print("Rank %d save dataset '%s' into %s!"%(mpi_rank, key, filename))
                    time.sleep(0.5)
                    break
            else:
                # All 10 attempts failed.
                raise IOError('Rank %d cannot save %s into %s!'%(mpi_rank, key, filename))
        mpi_comm.barrier()
def split_uneven_array(data, root=0, axis=0):
    '''
    Split ``data`` on ``root`` with ``np.array_split`` (pieces may differ in
    length along ``axis``) and scatter the pieces as pickled Python objects.

    ``data`` is only used on ``root``. Returns this rank's piece.
    '''
    pieces = data
    if mpi_rank == root:
        pieces = np.array_split(np.asarray(data), mpi_size, axis=axis)
    return mpi_comm.scatter(pieces, root=root)
def split_even_array(data, root=0, axis=0):
    '''
    Split ``data`` on ``root`` into ``mpi_size`` equal pieces along ``axis``
    and scatter them with the buffer-based ``Scatter`` (no pickling).

    ``data`` is only used on ``root``; other ranks may pass e.g. None.

    Returns
    -------
    numpy.ndarray
        This rank's piece. Its shape is root's shape with ``shape[axis]``
        divided by ``mpi_size``; its dtype is root's dtype.

    Raises
    ------
    ValueError
        On every rank, if root's length along ``axis`` is not divisible by
        ``mpi_size``. Raising on every rank means none of them is left
        blocked in a collective call.
    '''
    if mpi_rank == root:
        data = np.asarray(data)
        meta = (list(data.shape), data.dtype)
    else:
        meta = None
    # Shape and dtype go first so every rank can validate and allocate.
    shp, dtype = mpi_comm.bcast(meta, root=root)
    if shp[axis] % mpi_size != 0:
        raise ValueError('Axis %d with length %d cannot exactly divided by mpi size %d!'%(axis, shp[axis], mpi_size))
    if mpi_rank == root:
        # Stacking the pieces along a new leading axis gives a C-contiguous
        # buffer whose consecutive chunks are the per-rank pieces.
        sendbuf = np.asarray(np.array_split(data, mpi_size, axis=axis))
    else:
        sendbuf = None  # ignored on non-root ranks
    shp[axis] = process_size(shp[axis])
    new_data = np.empty(shp, dtype=dtype)
    mpi_comm.Scatter(sendbuf, new_data, root=root)
    return new_data
def split_array(data, root=0, axis=0):
    '''
    Split ``data`` (only used on ``root``) along ``axis`` and scatter it.

    ``root`` checks whether the length of ``axis`` is divisible by
    ``mpi_size`` and broadcasts the answer, so all ranks take the same path:
    ``split_even_array`` (buffer-based) if it is divisible, otherwise
    ``split_uneven_array`` (pickle-based).
    '''
    divisible = None
    if mpi_rank == root:
        data = np.asarray(data)
        divisible = data.shape[axis] % mpi_size == 0
    divisible = mpi_comm.bcast(divisible, root=root)
    if not divisible:
        print('Split and scatter as python object!')
        return split_uneven_array(data, root=root, axis=axis)
    print('Split and scatter as numpy array!')
    return split_even_array(data, root=root, axis=axis)
def bcast_array(data, root=0):
    '''
    Broadcast a numpy array from ``root`` to all ranks with the buffer-based
    ``Bcast``.

    Shape and dtype are sent first (pickled) so non-root ranks can allocate
    a receive buffer. ``data`` is ignored on non-root ranks.

    Returns
    -------
    numpy.ndarray
        The broadcast array on every rank. On ``root`` it is ``data`` as a
        C-contiguous array; no copy is made if it already is one.
    '''
    if mpi_rank == root:
        # Bcast moves raw memory, and non-root ranks allocate C-ordered
        # buffers. The root buffer must therefore be C-contiguous too,
        # otherwise strided or Fortran-ordered views are rejected or
        # received scrambled.
        data = np.asarray(data, order='C')
        meta = (data.shape, data.dtype)
    else:
        meta = None
    shp, dtype = mpi_comm.bcast(meta, root=root)
    if mpi_rank != root:
        data = np.empty(shp, dtype=dtype)
    mpi_comm.Bcast(data, root=root)
    return data
def gather_array(data, root=0, axis=0, expand_dim=False, ascontiguous=True):
    '''
    Gather every rank's numpy array onto ``root``, in rank order.

    Parameters
    ----------
    data : array_like
        This rank's piece.
    root : int
        Rank that receives the merged array.
    axis : int
        Without ``expand_dim``: the axis along which the pieces are
        concatenated. With ``expand_dim``: the position of the new axis
        along which the pieces are stacked.
    expand_dim : bool
        Stack the pieces along a new axis instead of concatenating them.
        All ranks are assumed (not checked) to hold the same shape and dtype.
    ascontiguous : bool
        Buffer path only: return a C-contiguous copy instead of the view
        produced by ``np.moveaxis``.

    Returns
    -------
    numpy.ndarray or None
        The merged array on ``root``; None on all other ranks.

    Raises
    ------
    ValueError
        On every rank, if (without ``expand_dim``) the pieces differ in
        their number of dimensions or in any dimension other than ``axis``.

    Pieces of equal length along ``axis`` (or with ``expand_dim``) use the
    buffer-based ``Gather``. Otherwise they are pickled with ``gather`` and
    joined with ``np.concatenate``.
    '''
    data = np.asarray(data)
    shp = list(data.shape)
    if expand_dim:
        print('Gather as numpy array and expand axis=%d!'%axis)
        even = True
        new_shp = [mpi_size] + shp
        # Gather lays rank r's buffer at new_data[r]; it must be C-ordered to
        # match the C-ordered receive buffer.
        sendbuf = np.asarray(data, order='C')
    else:
        # Every rank needs every shape so all ranks choose the same
        # collective path below.
        all_shp = mpi_comm.allgather(shp)
        shp0 = all_shp[0]
        even = True
        total_len = shp0[axis]
        for ii in all_shp[1:]:
            if len(shp0) != len(ii):
                raise ValueError('Data from different mpi process should have the same number of dimensions! Shapes are: %s'%all_shp)
            shp1 = shp0.copy()
            shp2 = ii.copy()
            del shp1[axis]
            del shp2[axis]
            if shp1 != shp2:
                raise ValueError('Data from different mpi process should have the same shape except for the merge axis! Shapes are: %s'%all_shp)
            if ii[axis] != shp0[axis]:
                even = False
            total_len += ii[axis]
        if even:
            print('Gather as numpy array!')
            new_shp = shp0.copy()
            del new_shp[axis]
            new_shp = [total_len] + new_shp
            # Gather places the rank buffers end to end, i.e. it concatenates
            # along the leading axis. Move ``axis`` to the front first so this
            # becomes a concatenation along ``axis``. The np.moveaxis on root
            # below moves it back.
            sendbuf = np.ascontiguousarray(np.moveaxis(data, axis, 0))
        else:
            print('Gather as python object!')
    if even:
        # NOTE(review): root's dtype is used for the receive buffer; all
        # ranks are assumed to share it.
        if mpi_rank == root:
            new_data = np.empty(new_shp, dtype=data.dtype)
        else:
            new_data = None
        mpi_comm.Gather(sendbuf, new_data, root=root)
        if mpi_rank == root:
            new_data = np.moveaxis(new_data, 0, axis)
            if ascontiguous:
                new_data = np.ascontiguousarray(new_data)
        return new_data
    else:
        new_data = mpi_comm.gather(data, root=root)
        if mpi_rank == root:
            new_data = np.concatenate(new_data, axis=axis)
        return new_data
if __name__ == '__main__':
    # Self-test for gather_array. Each rank builds a reproducible random array
    # whose length along ``axis`` differs per rank. The arrays are gathered
    # on ``root``, which rebuilds the expected result locally and prints the
    # max absolute difference (0.0 means success).
    # Launch with an MPI runner, e.g. ``mpirun -n 4 python <this file>``.
    axis = 1
    expand_dim = False
    # Rank 1 was the original root; fall back to rank 0 when running with a
    # single process, where rank 1 does not exist.
    root = 1 if mpi_size > 1 else 0
    np.random.seed(mpi_rank+1)
    a = np.random.rand(10, mpi_rank+1, 20)
    print(np.shape(a), mpi_rank)
    a = gather_array(a, root=root, axis=axis, expand_dim=expand_dim)
    print(np.shape(a), mpi_rank)
    if mpi_rank == root:
        b = []
        for ii in range(mpi_size):
            np.random.seed(ii+1)
            b.append(np.random.rand(10, ii+1, 20))
            if expand_dim:
                b[-1] = np.expand_dims(b[-1], axis=axis)
        b = np.concatenate(b, axis=axis)
        print(np.abs(a - b).max())



