#!/usr/bin/python
################################################################################
# Create a mini custom database
# Supports the following commands:
#   db | use dbname, show databases | tables, desc tablename
#   create table tablename (colname1, colname2, etc.)
#   insert into tablename values(colvalue1, 'colvalue2', etc.)
#   select * | colnames from table1 [join table2 on table1.col1 = table2.col2]
#          [where colname op value]     (op: = != > < >= <=)
################################################################################
import os, sys, re

# O/S directory holding one sub-directory per database. It can be overridden
# with the SAMDB_LOCATION environment variable. os.path.join(..., '') makes
# sure the value ends with a path separator, which the handlers rely on
# because they build paths by plain concatenation (location + dbname + ...).
location = os.path.join(os.environ.get('SAMDB_LOCATION', '/home/sultans/web/sql/demo/SamDB/db/'), '')
dbname   = 'data'                      # default database (sub-directory of location)
data = []                              # global 2-dim list holding the rows of the last select

################################################################################
# Identify SQL command
################################################################################
def process(cmd2):
    """Dispatch a lower-cased command string to its handler.

    The handler is picked by the command's leading keyword, checked in a
    fixed order. The handlers themselves re-parse the global ``cmd``.
    Unknown commands print an error. A blank line is always printed
    afterwards to separate the output of successive commands.
    """
    handlers = (
        ('db', db),
        ('use', db),
        ('show', show),
        ('desc', desc),
        ('create', create),
        ('insert', insert),
        ('select', select),
    )
    for keyword, handler in handlers:
        if cmd2.startswith(keyword):
            handler()
            break
    else:
        print('Error: Invalid command')
    print()

################################################################################
# db dbname (or) use dbname - Change to a new database
################################################################################
def db():
    """Switch the current database: ``db dbname`` or ``use dbname``.

    The command is parsed from the global ``cmd``. The requested name is
    lower-cased and must be an existing directory under ``location``.
    The global ``dbname`` is changed only after that check succeeds, so a
    failed switch leaves the current database selected.

    Returns True on success, False (after printing an error) otherwise.
    """
    global dbname

    ### Parse use command --------------------------------------
    found = re.search(r'^\s*(db|use)\s+(\w+)', cmd, re.IGNORECASE)
    if not found:
        print("Error: Invalid db or use command -> db dbname")
        return False
    newname = found.group(2).lower()                # the database name

    ### Validate before switching -------------------------------
    # os.path.isdir() reports OS errors as False, so no try/except is
    # needed. A plain file with the same name is not a database.
    if not os.path.isdir(location + newname):
        print("Error: Database", newname, "does not exist")
        return False
    dbname = newname
    print("Switching to database", dbname)
    return True

################################################################################
# Show databases | tables - Display the list of databases or tables
################################################################################
def show():
    """Display databases or tables: ``show databases | tables``.

    ``show databases`` lists the entries of ``location``. ``show tables``
    lists the table files of the current database (``location + dbname``)
    and treats an empty database as an error. Names are printed sorted,
    under the upper-cased keyword as a heading.

    Returns True on success, False (after printing an error) otherwise.
    """
    ### Parse show command --------------------------------------
    found = re.search(r'^\s*show\s+(databases|tables)', cmd, re.IGNORECASE)
    if not found:
        print("Error: Invalid show command -> show databases | tables")
        return False
    kind = found.group(1).lower()                   # 'databases' or 'tables'
    dirname = location + dbname if kind == 'tables' else location

    ### List the directory --------------------------------------
    try:
        names = sorted(os.listdir(dirname))
    except OSError:
        print("Error accessing", dbname, "database")
        return False
    if kind == 'tables' and not names:
        print("Error: No tables in", dbname, "database")
        return False
    print(kind.upper())                             # DATABASES or TABLES
    for name in names:
        print(name)
    return True

################################################################################
# Desc tablename - Display list of columns within a table
################################################################################
def desc():
    """Display the columns of a table: ``desc tablename``.

    The table name is parsed from the global ``cmd`` and lower-cased. The
    column names are the ", "-separated fields of the first line of the
    table file ``location + dbname + '/' + table``. They are printed one
    per line under a "COLUMNS" heading.

    Returns True on success, False (after printing an error) otherwise.
    """
    ### Parse desc command --------------------------------------
    found = re.search(r'^\s*desc\s+(\w+)', cmd, re.IGNORECASE)
    if not found:
        print("Error: Invalid desc command -> desc tablename")
        return False
    table = found.group(1).lower()
    filename = location + dbname + '/' + table
    if not os.path.exists(filename):                # file must exist
        print("Error: Table", table, "does not exist")
        return False

    ### Read the header line before printing anything ------------
    try:
        with open(filename, 'r') as file:
            header = file.readline().rstrip()
    except (OSError, UnicodeDecodeError):
        print("Error accessing", table)
        return False
    print("COLUMNS")
    for col in header.split(', '):
        print(col)
    return True
        
################################################################################
# Create table tablename (col1, col2, etc.)
################################################################################
def create():
    """Create a table: ``create table tablename (col1, col2, etc.)``.

    The table name (lower-cased) and the parenthesised column list are
    parsed from the global ``cmd``. Column names are whitespace-stripped
    and must be non-empty. They keep their case, because SELECT, WHERE and
    JOIN compare column names case-sensitively. The table file
    ``location + dbname + '/' + table`` is created exclusively and its
    first line is the ", "-joined column names.

    Returns True on success, False (after printing an error) otherwise.
    """
    ### Parse create command --------------------------------------
    found = re.search(r'^\s*create\s+table\s+(\w+)', cmd, re.IGNORECASE)
    if not found:
        print("Error: Invalid create command -> create table tablename")
        return False
    table = found.group(1).lower()
    found = re.search(r'\((.+)\)', cmd)
    if not found:
        print("Error: Please provide column names -> create table", table, "(col1, col2, etc.)")
        return False
    # Strip padding such as "( a, b )" so the header holds "a" and "b",
    # which later lookups compare against exactly.
    columns = [col.strip() for col in re.split(r',\s*', found.group(1))]
    if not all(columns):
        print("Error: Column names must not be empty")
        return False

    ### Create the table and write its header in one step ----------
    # Mode 'x' fails if the table already exists. Writing through the same
    # handle avoids leaving an empty table file behind a separate reopen.
    filename = location + dbname + '/' + table
    columnList = ", ".join(columns)
    try:
        with open(filename, 'x') as file:
            file.write(columnList + '\n')           # write col headers
    except FileExistsError:
        print("Error: Table already exists")
        return False
    except FileNotFoundError:                       # database directory missing
        print("Error: Database", dbname, "does not exist")
        return False
    except OSError:
        print("Error: Could not write columns into table", table)
        return False
    print("Table", table, "created with columns", columnList)
    return True

################################################################################
# Insert into tablename values(val1, 'val2', etc.)
################################################################################
def insert():
    """Append a row: ``insert into tablename values(val1, 'val2', etc.)``.

    The table name (lower-cased) and the value list are parsed from the
    global ``cmd``. Each value is whitespace-stripped and must be either
    numeric (``str.isnumeric``) or enclosed in single quotes. Only the
    enclosing quotes are removed, so quotes inside a value are kept. The
    number of values must equal the number of columns in the table's
    header line. The row is appended as a ", "-separated line.

    Returns True on success, False (after printing an error) otherwise.
    """
    ### Parse insert command --------------------------------------
    found = re.search(r'^\s*insert\s+into\s+(\w+)', cmd, re.IGNORECASE)
    if not found:
        print("Error: Invalid insert command -> insert into tablename")
        return False
    table = found.group(1).lower()
    found = re.search(r'values\s*\((.+)\)', cmd, re.IGNORECASE)
    if not found:
        print("Error: Please provide column values -> insert into", table, "values(val1, val2, etc.)")
        return False
    values = []
    for val in re.split(r',\s*', found.group(1)):
        val = val.strip()                           # e.g. "values( 1, 'a' )"
        if val.isnumeric():
            values.append(val)
        elif re.fullmatch(r"'.+'", val):            # 'value' must be in quotes
            values.append(val[1:-1])                # drop only the enclosing quotes
        else:
            print("Error: Column value", val, "must be enclosed in single quotes")
            return False

    ### Retrieve column headers --------------------------------
    filename = location + dbname + '/' + table
    if not os.path.exists(filename):                # file must exist
        print("Error: Table", table, "does not exist")
        return False
    try:
        with open(filename, 'r') as file:
            columns = file.readline().rstrip().split(', ')   # col headers is line1
    except (OSError, UnicodeDecodeError):
        print('Error: Could not retrieve column headers')
        return False

    ### Write new table row --------------------------------
    if len(values) != len(columns):
        print("Error: Column values do not match number of columns", len(columns))
        return False
    valueList = ", ".join(values)
    try:
        with open(filename, 'a') as file:           # open file for append
            file.write(valueList + '\n')
    except OSError:
        print('Error: Could not insert row in table')
        return False
    print("Row: [", valueList, "] inserted")
    return True

################################################################################
# select * | col1,col2,etc. from tablename
################################################################################
def select():
    """Run a query: ``select * | col1, col2 from table [join ...] [where ...]``.

    The requested columns and the FROM table are parsed from the global
    ``cmd``. Requested column names are whitespace-stripped, and the
    table name is lower-cased. Rows are loaded with select_join() when a
    JOIN keyword is present and with select_single() otherwise. A WHERE
    keyword then filters them with select_where(), and the result is
    printed with display().

    Returns the success flag of the load/filter steps (False after a
    printed error).
    """
    ### Parse select command --------------------------------------
    # Lazy capture: the column list ends at the FIRST " from", so a later
    # "from" (e.g. inside a WHERE value) cannot leak into it.
    found = re.search(r'^\s*select\s+(.+?)\s+from\b', cmd, re.IGNORECASE)
    if not found:
        print("Error: Please provide * or column names from tablename")
        return False
    reqColumns = [col.strip() for col in re.split(r',\s*', found.group(1))]
    found = re.search(r'\bfrom\s+(\w+)', cmd, re.IGNORECASE)
    if not found:
        print("Error: Please provide a table name")
        return False
    table = found.group(1).lower()

    ### Determine select clauses --------------------------------
    # Whole-word tests, so identifiers such as "joined" or "somewhere" do
    # not trigger a JOIN or WHERE.
    if re.search(r'\bjoin\b', cmd, re.IGNORECASE):
        ok = select_join(table)
    else:
        ok = select_single(table)

    if ok and re.search(r'\bwhere\b', cmd, re.IGNORECASE):
        ok = select_where()

    if ok:
        display(reqColumns)
    return ok
    
################################################################################
# Select single table, no join, no where 
################################################################################
def select_single(table):
    """Load one table into the global ``data`` list (no join, no where).

    Each line of ``location + dbname + '/' + table`` becomes a list of its
    ", "-separated fields, after trailing whitespace is stripped. The
    first row is the column header. ``data`` is replaced only after the
    whole file has been read, so a failed read leaves it untouched.

    Returns True on success, False (after printing an error) otherwise.
    """
    global data
    filename = location + dbname + '/' + table
    if not os.path.exists(filename):                # file must exist
        print("Error: Table", table, "does not exist")
        return False
    try:
        with open(filename, 'r') as file:
            rows = [line.rstrip().split(', ') for line in file]
    except (OSError, UnicodeDecodeError):
        print('Error: Could not retrieve data from table')
        return False
    data.clear()                                    # keep the same global list object
    data.extend(rows)
    return True

################################################################################
# Select join tables 
################################################################################
def select_join(table1):
    """Inner-join ``table1`` with the table named in the JOIN clause.

    Parses ``join table2 on a.col = b.col`` from the global ``cmd``. Table
    names are lower-cased, and the ON clause must reference exactly the
    two joined tables, in either order. On success the global ``data``
    holds one header row (table1 columns followed by table2 columns),
    then one concatenated row for every pair of data rows whose key
    values are equal. This may be no rows at all.

    Returns True on success, False (after printing an error) otherwise.
    """
    global data

    found = re.search(r'\sjoin\s+(\w+?)\s+on\s+(\w+\.\w+)\s*=\s*(\w+\.\w+)', cmd, re.IGNORECASE)
    if not found:
        print("Error: Please provide table1 join table2 on table1.col = table2.col")
        return False
    table2 = found.group(1).lower()                 # lower-cased like table1
    tbl1, col1 = found.group(2).split('.')          # left side of ON:  table.col
    tbl2, col2 = found.group(3).split('.')          # right side of ON: table.col
    tbl1, tbl2 = tbl1.lower(), tbl2.lower()

    # The ON clause must name both joined tables (e.g. reject
    # "t1 join t2 on t1.a = t1.b").
    if {tbl1, tbl2} != {table1, table2}:
        print("Error: Table names in ON clause do not match select tables")
        return False
    if tbl1 != table1:                              # written as "on table2.x = table1.y":
        col1, col2 = col2, col1                     # make col1 belong to table1

    filename1 = location + dbname + '/' + table1
    filename2 = location + dbname + '/' + table2
    for table, filename in ((table1, filename1), (table2, filename2)):
        if not os.path.exists(filename):            # file must exist
            print("Error: Table", table, "does not exist")
            return False

    try:
        with open(filename1, 'r') as file:
            data1 = [line.rstrip().split(', ') for line in file]
        with open(filename2, 'r') as file:
            data2 = [line.rstrip().split(', ') for line in file]

        if not data1 or not data2 or col1 not in data1[0] or col2 not in data2[0]:
            print("Error: Column names in ON clause do not exist")
            return False
        i = data1[0].index(col1)                    # merge key position in table1
        j = data2[0].index(col2)                    # merge key position in table2

        # Header rows stay out of the nested loop so they can never pair with
        # each other (or with a data value). The result always starts with
        # exactly one header row, even when no data rows match.
        result = [data1[0] + data2[0]]
        for row1 in data1[1:]:
            for row2 in data2[1:]:
                if row1[i] == row2[j]:              # keys equal -> merge the 2 rows
                    result.append(row1 + row2)
    except (OSError, UnicodeDecodeError, IndexError):   # IndexError: row too short for its key
        print('Error: Could not retrieve data from tables')
        return False

    data[:] = result                                # replace the global rows in place
    return True

################################################################################
# Select with where clause (supports: = != > < >= <= )
################################################################################
def _as_number(text):
    """Return *text* as a float when it is a plain decimal number, else None."""
    if re.fullmatch(r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)', text.strip()):
        return float(text)
    return None


def select_where():
    """Apply ``where col op value`` to the rows in the global ``data``.

    Supported operators: = != > < >= <=. Whitespace around the value is
    stripped and single quotes are removed from it. When both the cell
    and the value are plain decimal numbers they are compared numerically
    (so '10' > '9'); otherwise they are compared as strings. Failing rows
    are not removed: their WHERE-column cell is overwritten with the
    marker 'no_print', and display() skips any row containing that marker.

    Returns True on success, False (after printing an error) otherwise.
    """
    global data

    found = re.search(r'\swhere\s+(.+?)\s*(=|!=|>=|<=|>|<)\s*(.*)', cmd, re.IGNORECASE)
    if not found:
        print("Error: Please provide select ... where col=value")
        return False
    whereCol, whereOpr, whereVal = found.groups()
    # Strip blanks typed after the value first, then eliminate the quotes.
    whereVal = whereVal.strip().replace("'", "")
    if not data:
        print("Error: No data to filter")
        return False
    colHeaders = data[0]                            # headers are first row
    if whereCol not in colHeaders:
        print("Error: Invalid column name", whereCol, "in where clause")
        return False

    comparisons = {
        '=':  lambda a, b: a == b,
        '!=': lambda a, b: a != b,
        '>':  lambda a, b: a > b,
        '<':  lambda a, b: a < b,
        '>=': lambda a, b: a >= b,
        '<=': lambda a, b: a <= b,
    }
    compare = comparisons[whereOpr]
    numVal = _as_number(whereVal)

    for row in data[1:]:
        for j, header in enumerate(colHeaders):
            if header != whereCol or j >= len(row):
                continue
            numCell = _as_number(row[j])
            if numVal is not None and numCell is not None:
                keep = compare(numCell, numVal)     # numeric comparison
            else:
                keep = compare(row[j], whereVal)    # textual comparison
            if not keep:
                row[j] = 'no_print'                 # mark it as no_print
    return True

################################################################################
# Display * or requested columns only
################################################################################
def display(reqColumns):
    """Print the global ``data`` rows, limited to the requested columns.

    ``reqColumns`` is either a list starting with '*' (all columns) or a
    list of header names, each of which must exist in the header row.
    Rows containing the marker 'no_print' are skipped. Header names are
    printed upper-cased, and every printed cell is followed by a space and
    a tab. The last line reports the number of printed rows minus one
    (the header).

    Returns True on success, False (after printing an error) otherwise.
    """
    if not data:                                    # e.g. an empty table file
        print("Error: No data to display")
        return False
    colHeaders = data[0]                            # column headers are row 1

    for col in reqColumns:
        if col != '*' and col not in colHeaders:
            print("Error: Column", col, "does not exist")
            return False

    showAll = reqColumns[:1] == ['*']
    count = 0
    for i, row in enumerate(data):
        if 'no_print' in row:                       # skip printing
            continue
        # zip() pairs each cell with its header and ignores surplus cells
        # instead of raising IndexError.
        for header, value in zip(colHeaders, row):
            if showAll or header in reqColumns:
                print(value.upper() if i == 0 else value, '\t', end='')
        print()
        count += 1
    print("Total number of rows:", count - 1)       # count - header row
    return True

################################################################################
# Main
# - Accept a SQL stmt on the command line:  [db=dbname] sql statement
# - Or, continuously prompt the user to enter SQL statements 
################################################################################
if __name__ == '__main__':
    # Module-level 'if' creates no scope: cmd and dbname below are the
    # globals the command handlers read.
    if len(sys.argv) > 1:                           # if SQL from command line
        cmd = ' '.join(sys.argv[1:])                # everything after the script name
        if cmd.startswith('db='):                   # optional leading "db=dbname"
            dbarg, _, cmd = cmd.partition(' ')      # no space -> empty command, no crash
            dbname = dbarg[3:]                      # the requested database name
        cmd2 = cmd.strip().lower()
        process(cmd2)
    else:                                           # interactive mode
        while True:
            try:
                cmd = input("Enter SQL command: ")  # user prompt
            except (EOFError, KeyboardInterrupt):   # Ctrl-D / Ctrl-C ends the session
                print()
                cmd = 'exit'
            cmd2 = cmd.strip().lower()
            if cmd2 in ['exit', 'quit', 'q']:
                print('Exiting')
                sys.exit(0)
            process(cmd2)

################################################################################