import os
import re

import nibabel as nib
import numpy as np


def mapper(data, threshold, max_threshold, *, alpha=0.0):
    """Map signed scalar values to RGBA colour rows.

    Each value is classified by the first rule that matches, in this order:

    * ``value >= threshold``: red-to-yellow ramp ``(1, v, 0)`` with
      ``v = (value - threshold) / (max_threshold - threshold)``; values
      strictly above ``max_threshold`` saturate at yellow ``(1, 1, 0)``.
    * ``value <= -threshold``: blue-to-green ramp ``(0, v, 1 - v)`` with
      ``v = (-value - threshold) / (max_threshold - threshold)``; values
      strictly below ``-max_threshold`` saturate at green ``(0, 1, 0)``.
    * anything else (including NaN, which compares False): white
      ``(1, 1, 1)``.

    Parameters
    ----------
    data : array_like
        1-D sequence of scalar values, one output row per element.
    threshold : float
        Magnitude at which colouring starts.
    max_threshold : float
        Magnitude at which the colour ramps saturate. If it equals
        ``threshold``, a value exactly at ``threshold`` divides 0 by 0 and
        yields NaN colour components.
    alpha : float, keyword-only
        Value written to the alpha channel of every row. The default 0.0
        reproduces the historical output. NOTE(review): viewers that honour
        alpha may treat 0.0 as fully transparent -- confirm whether 1.0 was
        intended.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(data), 4)`` and dtype float32.
    """
    # Compute in float64 and let assignment cast to float32; also accepts
    # plain Python sequences.
    values = np.asarray(data, dtype=np.float64)
    span = max_threshold - threshold

    rgba = np.zeros((len(values), 4), dtype='float32')
    rgba[:, 3] = alpha

    # The positive rule takes precedence over the negative one; both can
    # only hold at once when threshold is negative.
    pos = values >= threshold
    neg = ~pos & (values <= -threshold)
    rest = ~pos & ~neg

    rgba[rest, :3] = 1.0

    # Positive side: red channel always on, green ramps 0 -> 1.
    pos_sat = pos & (values > max_threshold)
    pos_ramp = pos & ~pos_sat
    rgba[pos, 0] = 1.0
    rgba[pos_sat, 1] = 1.0
    rgba[pos_ramp, 1] = (values[pos_ramp] - threshold) / span

    # Negative side: blue fades out as green fades in.
    neg_sat = neg & (values < -max_threshold)
    neg_ramp = neg & ~neg_sat
    neg_value = (-values[neg_ramp] - threshold) / span
    rgba[neg_sat, 1] = 1.0
    rgba[neg_ramp, 1] = neg_value
    rgba[neg_ramp, 2] = 1.0 - neg_value

    return rgba


def process(directory='.', threshold=3.28, max_threshold=6):
    """Convert every ``group_<base>_Z.R.func.gii`` file into an RGBA GIFTI.

    For each file in ``directory`` whose name fully matches
    ``group_<base>_Z.R.func.gii``, the first data array is colour-mapped
    with :func:`mapper` and saved as ``group_<base>_Z.R.rgba.gii`` in the
    same directory. The output is a single data array with intent
    ``NIFTI_INTENT_RGBA_VECTOR``.

    Parameters
    ----------
    directory : str
        Directory to scan and write into. Defaults to the current working
        directory, as before.
    threshold, max_threshold : float
        Passed straight through to :func:`mapper`. The defaults (3.28 and 6)
        are the previously hard-coded values.
    """
    # Raw string with escaped dots: the old pattern's bare '.' matched any
    # character, and re.match let trailing suffixes such as '.bak' through.
    pattern = re.compile(r'group_(.*)_Z\.R\.func\.gii')

    # Sorted so that the processing order is deterministic.
    for name in sorted(os.listdir(directory)):
        result = pattern.fullmatch(name)
        if not result:
            continue
        base = result.group(1)
        func = nib.load(os.path.join(directory, name))
        data = nib.gifti.GiftiDataArray(
            mapper(func.darrays[0].data, threshold, max_threshold),
            intent='NIFTI_INTENT_RGBA_VECTOR',
        )
        img = nib.gifti.GiftiImage()
        img.darrays = [data]
        nib.save(img, os.path.join(directory, f'group_{base}_Z.R.rgba.gii'))


# Script entry point: convert matching files in the current directory.
if __name__ == "__main__":
    process()
