from math import sqrt
import numpy as np
import matplotlib.pyplot as plt
import argparse

def rec_lat(lat):
    """Return the reciprocal lattice of ``lat``.

    ``lat`` holds the three real-space basis vectors as rows. The result holds
    b1, b2, b3 as rows, with b_i = (a_j x a_k) / V for cyclic (i, j, k) and V
    the signed cell volume a1 . (a2 x a3). This is the b_i . a_j = delta_ij
    convention, i.e. without the 2*pi factor.
    """
    a1, a2, a3 = np.array(lat)
    volume = np.dot(a1, np.cross(a2, a3))
    cyclic_pairs = ((a2, a3), (a3, a1), (a1, a2))
    return np.array([np.cross(u, w) / volume for u, w in cyclic_pairs])

def vol_lat(lat):
    """Return the signed volume of the cell spanned by the rows of ``lat``.

    Computed as the scalar triple product a1 . (a2 x a3); the result is
    negative for a left-handed basis.
    """
    first, second, third = np.array(lat)
    normal = np.cross(second, third)
    return np.dot(first, normal)

def hexagon(ax,c,r):
    """Draw a red regular hexagon centred at ``c`` with circumradius ``r``.

    The vertices sit at angles 0, 60, ..., 300 degrees around the centre, so
    one vertex points along +x. The outline is emitted as six separate
    two-point ``ax.plot`` calls; a seventh vertex at 360 degrees closes the
    loop.
    """
    cx, cy = c
    angles = [k*2*np.pi/6 for k in range(7)]
    corners = [[cx + r*np.cos(t), cy + r*np.sin(t)] for t in angles]

    # Join each vertex to the next one.
    for (x1, y1), (x2, y2) in zip(corners[:-1], corners[1:]):
        ax.plot([x1, x2], [y1, y2], 'r', linestyle='-')

def lat2d(ax,lat):
    """Draw the unit cell of a 2D lattice as a dashed black parallelogram.

    ``lat`` holds the two basis vectors as rows. The four edges are drawn
    first from the origin along lat[1], then from the origin along lat[0],
    and then the two opposite edges starting at lat[0] and at lat[1].
    """
    origin = [0, 0]
    edges = (
        (origin, lat[1]),
        (origin, lat[0]),
        (lat[0], lat[1]),
        (lat[1], lat[0]),
    )
    for start, step in edges:
        x0, y0 = start
        sx, sy = step
        ax.plot([x0, x0 + sx], [y0, y0 + sy], 'k', linestyle='--', lw=1)

def plot_commensurate(ax,lat,N):
    """Plot the Gamma-centred N x N q-point grid of a 2D lattice on ``ax``.

    Draws, in this order:
    * hexagons around Gamma and around the reciprocal lattice points b1, b2
      and b1+b2;
    * the reciprocal unit cell;
    * the Gamma-M-K-Gamma path with its labels;
    * the grid points (i/N) b1 + (j/N) b2 for 0 <= i, j < N.

    NOTE(review): the hexagon geometry (one vertex along +x, circumradius
    |b1+b2|/3) and the M/K path match the hexagonal lattice built in
    ``__main__``. Other lattices are not expected to give a meaningful zone;
    confirm before reusing.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    lat : array-like, shape (2, 2)
        Real-space lattice vectors as rows.
    N : int
        Number of grid divisions along each reciprocal vector; must be >= 1.

    Raises
    ------
    ValueError
        If ``N`` < 1, because an empty grid cannot be scattered.
    """
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N!r}")

    # Embed the 2D lattice in 3D with a unit out-of-plane vector so that
    # rec_lat's cross products apply, then keep only the in-plane block.
    lat = np.array(lat)
    lat3d = np.eye(3)
    lat3d[0:2, 0:2] = lat
    rlat = rec_lat(lat3d)[0:2, 0:2]

    # Zone hexagons: |b1 + b2| / 3 is the Gamma-K distance used as radius.
    r = np.linalg.norm(rlat[0] + rlat[1]) / 3
    for center in ([0, 0], rlat[0] + rlat[1], rlat[0], rlat[1]):
        hexagon(ax, center, r)

    # Reciprocal unit cell.
    lat2d(ax, rlat)

    # High-symmetry path, given in reduced (b1, b2) coordinates.
    path_red = [[0.0, 0.0],
                [1.0/2, 0.0],
                [1.0/3, 1.0/3],
                [0.0, 0.0]]
    # Raw strings: '\G' is an invalid escape in a normal string literal.
    path_labels = [r'$\Gamma$', 'M', 'K', r'$\Gamma$']
    path_car = [np.dot(rlat.T, point_red) for point_red in path_red]
    for (x1, y1), (x2, y2) in zip(path_car[:-1], path_car[1:]):
        ax.plot([x1, x2], [y1, y2], 'b', linestyle='-')
    for (x, y), label in zip(path_car, path_labels):
        ax.text(x, y, label)

    # Gamma-centred grid. Rows are reduced coordinates, so row @ rlat equals
    # rlat.T @ row, i.e. the Cartesian position.
    points_red = [[ix / N, iy / N] for ix in range(N) for iy in range(N)]
    points = np.array(points_red) @ rlat
    ax.scatter(points[:, 0], points[:, 1])

    ax.set_aspect('equal')

if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description='Plot the q points of a hexagonal lattice commensurate '
                    'with an NxN supercell and save the figure to NxN.pdf')
    # Required: without it N would be None and the script would crash later
    # with an obscure TypeError.
    parser.add_argument('-n', action='store', type=int, required=True,
                        help='number of repetitions of the supercell')
    args = parser.parse_args()
    if args.n < 1:
        parser.error('-n must be a positive integer')

    # Hexagonal lattice: a1 = (1, 0), a2 = (-1/2, sqrt(3)/2).
    lat = np.array([[   1.0,      0.0],
                    [-1.0/2,sqrt(3)/2]])

    fig, ax = plt.subplots()
    # Transparent figure background and no axes, so only the drawing shows.
    fig.patch.set_visible(False)
    ax.axis('off')

    N = args.n
    plot_commensurate(ax, lat, N)

    fig.set_size_inches(5, 5)
    fig.savefig(f'{N}x{N}.pdf')
    plt.show()

