# -*- coding: utf-8 -*-
"""
Created on Tue Mar 31 11:49:54 2020

@author: Glenn Viveen
"""

import numpy as np
from scipy.special import comb

def clebsch_gordan(J):
    """Return the Clebsch-Gordan coefficient <j1 m1; j2 m2 | j3 m3>.

    The coefficient is evaluated from a single alternating sum of products of
    binomial coefficients.  Both the sum and the normalisation are built from
    exact integers, so the result does not suffer from the cancellation
    errors that floating-point binomials introduce for large angular momenta.

    Parameters
    ----------
    J : array_like, shape (2, 3)
        ``[[j1, j2, j3], [m1, m2, m3]]``: the two coupled angular momenta
        (columns 0 and 1) and the resulting one (column 2), with their
        projections in the second row.  The input is copied, never modified.

    Returns
    -------
    float
        The coefficient.  ``0.0`` is returned when ``m1 + m2 != m3`` or when
        ``j1``, ``j2``, ``j3`` violate the triangle inequality.

    Raises
    ------
    ValueError
        If ``J`` does not have shape (2, 3), if any ``|m| > j``, or if a
        ``(j, m)`` pair is not a valid (half-)integer combination.
    """
    qnarray = np.copy(J)
    if qnarray.shape != (2, 3):
        raise ValueError("J must be of shape (2, 3)")
    # Work in the 3j-symbol convention, where the three projections add up to
    # zero.  Consequently the m3 reported in the messages below is the
    # sign-flipped input value.
    qnarray[1, 2] = -qnarray[1, 2]

    j1, j2, j3 = qnarray[0]
    m1, m2, m3 = qnarray[1]

    if (m1 < -j1 or m1 > j1):
        raise ValueError("m1 out of range. m1 = {}, j1= {}".format(m1, j1))
    if (m2 < -j2 or m2 > j2):
        raise ValueError("m2 out of range. m2 = {}, j2= {}".format(m2, j2))
    if (m3 < -j3 or m3 > j3):
        raise ValueError("m3 out of range. m3 = {}, j3= {}".format(m3, j3))

    if (j1 % 0.5 != 0 or (j1 - np.abs(m1)) % 1 != 0):
        raise ValueError("j1 and m1 must both be a (half-)integer.")
    if (j2 % 0.5 != 0 or (j2 - np.abs(m2)) % 1 != 0):
        raise ValueError("j2 and m2 must both be a (half-)integer.")
    if (j3 % 0.5 != 0 or (j3 - np.abs(m3)) % 1 != 0):
        raise ValueError("j3 and m3 must both be a (half-)integer.")

    if (m1 + m2 + m3) != 0:
        return 0.0

    def as_int(value):
        # After the validation above every combination used below is an
        # exact integer stored in a float (or NumPy integer) scalar.
        return int(np.round(value))

    # Triangle inequality: all three combinations must be non-negative.
    i1 = as_int(-j1 + j2 + j3)
    i2 = as_int(j1 - j2 + j3)
    i3 = as_int(j1 + j2 - j3)
    if i1 < 0 or i2 < 0 or i3 < 0:
        return 0.0

    # j -/+ m are guaranteed non-negative by the range checks above.
    k2 = as_int(j2 + m2)
    l1 = as_int(j1 - m1)
    l2 = as_int(j2 - m2)
    l3 = as_int(j3 - m3)
    n1 = as_int(-j1 - m2 + j3)
    n2 = as_int(m1 - j2 + j3)

    two_j1 = as_int(2 * j1)
    two_j2 = as_int(2 * j2)
    two_j3 = as_int(2 * j3)
    j_total = as_int(j1 + j2 + j3)

    # Summation bounds: keep every binomial argument inside its valid range.
    imin = max(-n1, -n2, 0)
    imax = min(l1, k2, i3)
    if imin > imax:
        return 0.0

    # The 3j symbol carries a sign (-1)**(n3 + i) per term, with
    # n3 = j1 - j2 - m3, and the 3j -> Clebsch-Gordan conversion multiplies
    # by (-1)**n3 * sqrt(2*j3 + 1).  The two n3 phases cancel, leaving
    # (-1)**i, and the sqrt(2*j3 + 1) cancels the matching factor in the
    # 3j normalisation.
    alternating_sum = 0
    for i in range(imin, imax + 1):
        term = (comb(i1, n1 + i, exact=True)
                * comb(i2, n2 + i, exact=True)
                * comb(i3, i, exact=True))
        alternating_sum += -term if i % 2 else term

    numerator = comb(two_j1, i2, exact=True) * comb(two_j2, i3, exact=True)
    denominator = (comb(j_total + 1, i3, exact=True)
                   * comb(two_j1, l1, exact=True)
                   * comb(two_j2, l2, exact=True)
                   * comb(two_j3, l3, exact=True))

    # The squared coefficient is an exact rational number; Python's integer
    # true division rounds it correctly even when the integers are huge.
    magnitude = np.sqrt(alternating_sum ** 2 * numerator / denominator)

    return -magnitude if alternating_sum < 0 else magnitude