#===============================================================================================
# Contains the following functions
# - connect:      connect to a MySQL database using an encrypted connection string
# - query:        query data from a MySQL database
# - createTable:  create a table in a MySQL database
# - dropTable:    drop a table from a MySQL database
# - insert:       insert data into a MySQL table
#===============================================================================================
# Connect to a database 
# receives: database name (connection string)
# returns:  connection object
#===============================================================================================
def connect(db_name):
    """Open a connection to the MySQL database identified by *db_name*.

    The real connection parameters are looked up by the project-local
    ``encrypt`` module, which returns a comma-separated string in the form
    ``host,user,pswd,db,port``.

    Returns a pymysql connection object.  If the connection fails, the error
    details are printed and the process exits with status 1.
    """
    import sys
    import pymysql as mydb                                              #MySQL driver for Python 3.x

    util_dir = '/home/s/sultans/web/python/demo/etl/util'
    if util_dir not in sys.path:                                        #don't grow sys.path on every call
        sys.path.insert(0, util_dir)
    import encrypt                                                      #import encryption functions

    connect_string = encrypt.get_connection_string(db_name, 'mysql')   #format: host,user,pswd,db,port
    # NOTE(review): a plain split means none of the fields (e.g. the password)
    # may contain a comma.
    (hst, usr, pswd, dbname, prt) = connect_string.split(',')
    port = int(prt)

    try:
        conn = mydb.connect(host=hst, user=usr, password=pswd, db=dbname, port=port)
    except mydb.Error as e:
        # Errors are expected to carry (errno, message), but print whatever args
        # exist instead of unpacking exactly two (which would raise ValueError
        # and hide the real problem).
        print('Could not connect to database -', *e.args)
        sys.exit(1)                                                     #non-zero status: this is a failure

    return conn
 
#===============================================================================================
# Query data from Mysql database 
# receives: database name (connection string), and sql query to execute
# returns:  list of dictionaries, format: [{col_name1:col_value1, col_name2:col_value2}, {...}]
#===============================================================================================
def query(db_name, sql):
    """Run *sql* against *db_name* and return the result rows as dictionaries.

    Returns a list with one dict per row, ``{col_name: col_value, ...}``.
    Every value is converted with ``str()``, and SQL NULL (Python ``None``)
    becomes ``''``.  A statement that produces no result set (DB-API
    ``cursor.description is None``, e.g. UPDATE/DELETE) returns ``[]``.

    If the statement fails, a message, the offending SQL and the error details
    are printed and the process exits with status 1.  The cursor and the
    connection are always closed.
    """
    import sys
    from contextlib import closing

    with closing(connect(db_name)) as conn, closing(conn.cursor()) as cursor:
        try:
            cursor.execute(sql)                                #execute the query
        except Exception as e:                                 #not bare: let Ctrl-C through
            print("Could not execute query")                   #print error message
            print(sql)                                         #print the offending sql
            print(*e.args)                                     #and why it failed
            sys.exit(1)

        if cursor.description is None:                         #no result set to read
            return []

        # Each description entry is a 7-item sequence; item 0 is the column name.
        col_names = [col_info[0] for col_info in cursor.description]

        return [
            {name: '' if value is None else str(value)
             for name, value in zip(col_names, row)}
            for row in cursor.fetchall()
        ]

#===============================================================================================
# Drop/delete table 
# receives: database name (connection string), table name
# returns:  None
#===============================================================================================
def dropTable(db_name, table_name):
    """Drop table *table_name* from the database *db_name* and commit.

    Returns None.  On failure, the error details are printed and the process
    exits with status 1.  The cursor and the connection are always closed.

    NOTE(review): table_name is concatenated into the SQL unescaped
    (identifiers cannot be bound as query parameters), so only pass trusted
    table names.
    """
    import sys
    from contextlib import closing

    sql = 'DROP TABLE ' + table_name

    with closing(connect(db_name)) as conn, closing(conn.cursor()) as cursor:
        try:
            cursor.execute(sql)                                #execute the statement
            conn.commit()                                      #commit to database
        except Exception as e:
            # Print whatever args the error carries; unpacking exactly two
            # would raise ValueError for other exceptions and hide the cause.
            print("Could not drop table -", *e.args)
            sys.exit(1)
    
#===============================================================================================
# Create table 
# receives: database name (connection string), table name, and a list of dictionaries
# returns:  None
#===============================================================================================
def createTable(db_name, table_name, list_dict):
    """Create *table_name* in *db_name*, inferring column types from sample data.

    Only the first dictionary in *list_dict* is inspected.  Its keys become the
    column names (in dict order), and each column type is chosen as follows:

    * name contains 'date', '_dt' or 'dt_', or the value starts like
      'dd-MON-yy'                       -> DATE
    * value is an int                   -> BIGINT
    * value is a float                  -> DOUBLE
    * value is a str                    -> VARCHAR(250)
    * anything else                     -> 'UNKNOWN DATATYPE' (invalid SQL, so
      the CREATE fails and is reported)

    Raises IndexError if *list_dict* is empty.  On a database error, the error
    details are printed and the process exits with status 1.

    NOTE(review): table and column names are concatenated unescaped, so only
    pass trusted names.
    """
    import re
    import sys
    from contextlib import closing

    if not list_dict:
        raise IndexError('createTable: list_dict is empty, cannot infer columns')
    first_entry = list_dict[0]                                  #types come from the first row only

    col_defs = []
    for colname, colvalue in first_entry.items():               #for every name/value
        datatype = type(colvalue)                               #exact type (bool is not int here)
        if re.search(r'(date|_dt|dt_)', colname) \
        or re.match(r'^[0123]?\d-\w\w\w-\d\d', str(colvalue)):  #if starts with dd-MON-yy
            sqltype = 'DATE'
        elif datatype is int:
            sqltype = 'BIGINT'                                  #MySQL has no NUMBER type
        elif datatype is float:
            sqltype = 'DOUBLE'
        elif datatype is str:
            sqltype = 'VARCHAR(250)'
        else:
            sqltype = 'UNKNOWN DATATYPE'
        col_defs.append('    ' + colname + '\t ' + sqltype)

    sql = 'CREATE TABLE ' + table_name + '\n( \n' + ', \n'.join(col_defs) + '\n)'
#   print(sql)                                                 #print (debug)

    with closing(connect(db_name)) as conn, closing(conn.cursor()) as cursor:
        try:
            cursor.execute(sql)                                #execute the statement
            conn.commit()                                      #commit to database
        except Exception as e:
            # Print whatever args the error carries; unpacking exactly two
            # would raise ValueError for other exceptions and hide the cause.
            print("Could not create table -", *e.args)
            sys.exit(1)
    
#===============================================================================================
# Insert data into Mysql table 
# receives: database name (connection string), table name, and a list of dictionaries
# returns:  number of rows inserted
# PS.       commit for every 100 rows
#===============================================================================================
def insert(db_name, table_name, list_dict):
    """Insert each dictionary in *list_dict* as one row of *table_name*.

    Each dict's keys are used as column names.  Its values are passed to the
    driver as bound query parameters, so strings containing quotes are handled
    correctly and ``None`` is stored as NULL.  A commit is issued after every
    100 rows and once more at the end.

    Returns the number of rows inserted.  On a database error, the error
    details are printed and the process exits with status 1.  The cursor and
    the connection are always closed.

    NOTE(review): table and column names are still concatenated into the SQL
    (identifiers cannot be bound as parameters), so only pass trusted names.
    """
    import sys
    from contextlib import closing

    count = 0

    with closing(connect(db_name)) as conn, closing(conn.cursor()) as cursor:
        try:
            for row in list_dict:                              #loop thru list of dictionaries
                col_names = list(row.keys())
                placeholders = ', '.join(['%s'] * len(col_names))
                sql = ('INSERT INTO ' + table_name
                       + ' (' + ', '.join(col_names) + ') '
                       + 'VALUES(' + placeholders + ')')
#               print(sql)                                     #print (debug) to see the insert stmts
                cursor.execute(sql, list(row.values()))        #driver escapes/quotes each value
                count += 1
                if count % 100 == 0:                           #for every 100 inserts
                    conn.commit()                              #commit to database
            conn.commit()                                      #final commit
        except Exception as e:
            # Print whatever args the error carries; unpacking exactly two
            # would raise ValueError for other exceptions and hide the cause.
            print("Could not execute insert -", *e.args)
            sys.exit(1)

    return count
#===============================================================================================

#connect('demo')                                         #for test only
#result = query('demo','select *  from student')         #for test only
#print(result)                                           #for test only
#dropTable('demo','xxx')                                 #for test only
#createTable('demo','xxx',result)                        #for test only
#insert('demo','xxx',result)                             #for test only