import os
import re

import numpy as np
import nibabel as nib


def gen_mask(directory: str = '.', output: str = 'mask.nii.gz') -> None:
    """Write a binary mask marking voxels that have a per-voxel group result.

    The voxel grid (shape and affine) is taken from the first NIfTI file,
    in sorted name order, in *directory* whose name starts with ``group_``.
    The voxel at 0-based index ``(x, y, z)`` is set to 1 when a file named
    ``group_{x+1}_{y+1}_{z+1}_Z.nii.gz`` exists in *directory*; every other
    voxel is 0. The mask is saved as a ``uint8`` NIfTI-1 image that uses
    the reference image's affine.

    Parameters
    ----------
    directory : str, optional
        Directory to scan. Defaults to the current working directory.
    output : str, optional
        Path of the mask image to write. Defaults to ``'mask.nii.gz'``.

    Raises
    ------
    FileNotFoundError
        If *directory* contains no ``group_*`` NIfTI file to use as reference.
    ValueError
        If the reference image is not three-dimensional.
    """
    # List the directory once. Sorting makes the choice of reference image
    # deterministic, because os.listdir returns entries in arbitrary order.
    names = sorted(os.listdir(directory))

    # NOTE(review): presumably every group_* image shares one voxel grid, so
    # any of them can serve as the reference. Confirm against the producer.
    reference = next(
        (
            name
            for name in names
            if name.startswith('group_') and name.endswith(('.nii', '.nii.gz'))
        ),
        None,
    )
    if reference is None:
        raise FileNotFoundError(
            f"no 'group_*' NIfTI file found in {directory!r}"
        )

    data = nib.load(os.path.join(directory, reference))
    # Raise instead of assert so the check survives `python -O`.
    if data.ndim != 3:
        raise ValueError(
            f'expected a 3-D reference image, got shape {data.shape}'
        )

    mask = np.zeros(data.shape, dtype=np.uint8)
    x_max, y_max, z_max = mask.shape

    # Parse the per-voxel file names from the single listing instead of
    # stat-ing one candidate path per voxel. [1-9][0-9]* accepts exactly the
    # 1-based, ASCII, no-leading-zero numbers that f'{i+1}' would produce.
    pattern = re.compile(
        r'group_([1-9][0-9]*)_([1-9][0-9]*)_([1-9][0-9]*)_Z\.nii\.gz'
    )
    for name in names:
        match = pattern.fullmatch(name)
        if match is None:
            continue
        x, y, z = (int(group) - 1 for group in match.groups())
        # Files whose indices fall outside the reference grid are ignored.
        if x < x_max and y < y_max and z < z_max:
            mask[x, y, z] = 1

    nib.save(nib.Nifti1Image(mask, affine=data.affine), output)


if __name__ == "__main__":
    # Script entry point: build mask.nii.gz from the group_* files in the
    # current working directory.
    gen_mask()
