#####################################################################################
# Database access functions
# DB connect, insert/update/delete, select functions
#####################################################################################
import os
import pymysql   as mydb                #Mysql 3x database driver
import cx_Oracle as oradb               #Oracle database driver
import config

# Connection settings shared by the functions in this module.
host, port1, port2 = 'localhost', '3306', '1521'    # server host, MySQL port, Oracle port
# Module-wide connection: connect() stores the connection object here, or an
# error string when connecting fails; select()/update() read it.
conn = None

#====================================================================================
# connect(): Connect to a database
#     args: user, pswd, DBname, DBtype
#   return: connection object, or error string
#====================================================================================
def connect(user, pswd, DBname, DBtype='mysql'):
    """Open a database connection and store it in the module-level ``conn``.

    Args:
        user:   database user name.
        pswd:   password for *user*.
        DBname: database name (for Oracle, the part after ``host:port/`` in the DSN).
        DBtype: 'mysql' (default) or 'oracle'.

    Returns:
        The driver's connection object on success, otherwise an error string
        starting with 'Connection Error - '. The same value is stored in the
        module-level ``conn`` that select() and update() use.
    """
    global conn

    if DBtype == 'mysql':
        try:
            # Pass the module's MySQL port explicitly, mirroring the Oracle branch's use of port2.
            conn = mydb.connect(host=host, port=int(port1), user=user,
                                password=pswd, database=DBname)
        except mydb.Error as e:
            # pymysql errors usually carry (errno, message), but some carry only a
            # message; joining all args avoids an IndexError on the short form.
            conn = 'Connection Error - ' + ' '.join(str(arg) for arg in e.args)

    elif DBtype == 'oracle':
        # Keep an ORACLE_HOME the environment already provides; fall back to this install path.
        os.environ.setdefault('ORACLE_HOME', '/u01/app/oracle/product/12.2.0/dbhome_1')
        try:
            # Credentials are passed separately: a hand-built "user/pswd@host:port/db"
            # string is mis-parsed when the password contains '/' or '@'.
            conn = oradb.connect(user=user, password=pswd, dsn=f'{host}:{port2}/{DBname}')
        except oradb.DatabaseError as e:
            conn = f'Connection Error - {e}'

    else:
        # Previously an unknown DBtype silently returned whatever conn already held.
        conn = f'Connection Error - unsupported DBtype {DBtype!r}'

    if config.RUN_MODE == 'test': print('CONN=>', conn, '<br>')    # for debugging
    return conn
    
#====================================================================================
# select(): select from database
#     args: sql
#   return: list of dictionaries, or error string
#====================================================================================
def select(sql, params=None):
    """Run a query on the module-level ``conn`` and return its rows as dictionaries.

    Args:
        sql:    SQL statement to execute.
        params: optional bind parameters forwarded to ``cursor.execute`` so that
                values need not be spliced into *sql*. Placeholder syntax is
                driver-specific (check the driver's paramstyle). Default None
                executes *sql* with no parameters, as before.

    Returns:
        A list with one ``{column_name: value}`` dict per row, or an error
        string starting with 'Database Error - '.

    When ``config.RUN_MODE == 'test'`` the SQL and result are printed for debugging.
    """
    if config.RUN_MODE == 'test': print('SQL=>', sql, '<br>')    # for debugging

    if conn is None or isinstance(conn, str):
        # connect() was never called, or it failed and stored an error string in conn;
        # without this guard conn.cursor() raises an uncaught AttributeError.
        result = 'Database Error - not connected'
    else:
        try:
            cursor = conn.cursor()
            try:
                if params is None:
                    cursor.execute(sql)
                else:
                    cursor.execute(sql, params)
                # description holds one entry per column; element 0 is the column name
                columns = [col_info[0] for col_info in cursor.description or ()]
                rows = cursor.fetchall()
            finally:
                cursor.close()                  # release the cursor even if execute/fetch failed
        except mydb.Error as e:
            # pymysql errors usually carry (errno, message), but some carry only a message.
            result = 'Database Error - ' + ' '.join(str(arg) for arg in e.args)
        except oradb.DatabaseError as e:
            result = f'Database Error - {e}'
        else:
            result = [dict(zip(columns, row)) for row in rows]

    if config.RUN_MODE == 'test': print('RESULT=>', result, '<br>')    # for debugging
    return result

#====================================================================================
# update(): execute a data-modifying SQL statement and commit it
#     args: sql
#   return: number of affected rows or error string
#====================================================================================
def update(sql, params=None):
    """Execute a data-modifying statement on the module-level ``conn`` and commit.

    Args:
        sql:    SQL statement to execute (insert() and delete() also route here).
        params: optional bind parameters forwarded to ``cursor.execute``.
                Placeholder syntax is driver-specific. Default None executes
                *sql* with no parameters, as before.

    Returns:
        The cursor's rowcount (number of affected rows), or an error string
        starting with 'Database Error - '.

    When ``config.RUN_MODE == 'test'`` the SQL and result are printed for debugging.
    """
    if config.RUN_MODE == 'test': print('SQL=>', sql, '<br>')    # for debugging

    if conn is None or isinstance(conn, str):
        # connect() was never called, or it failed and stored an error string in conn.
        result = 'Database Error - not connected'
    else:
        try:
            cursor = conn.cursor()
            try:
                if params is None:
                    cursor.execute(sql)
                else:
                    cursor.execute(sql, params)
                # Read rowcount BEFORE close(): a closed DB-API cursor is unusable.
                result = cursor.rowcount
            finally:
                cursor.close()                  # release the cursor even if execute failed
            conn.commit()
        except mydb.Error as e:
            # pymysql errors usually carry (errno, message), but some carry only a message.
            result = 'Database Error - ' + ' '.join(str(arg) for arg in e.args)
        except oradb.DatabaseError as e:
            result = f'Database Error - {e}'

    if config.RUN_MODE == 'test': print('RESULT=>', result, '<br>')    # for debugging
    return result
   
#====================================================================================
# insert(): insert into database.
#     args: sql
#   return: number of affected rows or error string
#====================================================================================
def insert(sql):
    """Execute *sql* (intended for INSERT statements) by delegating to update().

    Returns whatever update() returns: the affected-row count or an error string.
    """
    return update(sql)

#====================================================================================
# delete(): delete from database
#     args: sql
#   return: number of affected rows or error string
#====================================================================================
def delete(sql):
    """Execute *sql* (intended for DELETE statements) by delegating to update().

    Returns whatever update() returns: the affected-row count or an error string.
    """
    return update(sql)

#====================================================================================
# main - used for testing only - the code is disabled by wrapping it in a string literal
#====================================================================================
"""
conn = connect('demo','demo','demo','mysql')
if type(conn)==str: print(conn); exit() 

result = insert("insert into output values('hello world')")
if type(result)==str: print(result); exit()
print(result)

result = select('select * from student')
if type(result)==str: print(result); exit()
print(result)
"""