##########################################################################################
# Functions to encrypt and decrypt strings
#   1. encrypt/decrypt:    Uses base64 encoding (reversible, no key - obfuscation only, NOT encryption)
#   2. encrypt2/decrypt2:  Uses cryptography.fernet for encryption
#   3. encrypt3/decrypt3:  Uses a custom character-rotation cipher (via str.translate)
# Also functions to
#   1. Add an encrypted connection string to a mysql or oracle keys file
#   2. Get and decrypt a connection string from a mysql or oracle keys file
##########################################################################################

#=========================================================================================
# encode a string with standard base64 - no key required
#=========================================================================================
def encrypt(string):
    """Return the base64 text of *string* after UTF-8 encoding it.

    Note: base64 is a reversible encoding, not encryption - anyone can decode
    the result with decrypt() or any base64 tool.
    """
    import base64
    # base64 output is pure ASCII, so decoding it as 'ascii' always succeeds
    return base64.b64encode(bytes(string, 'utf-8')).decode('ascii')

#=========================================================================================
# decode a standard base64 string - no key required
#=========================================================================================
def decrypt(encrypted_string):
    """Reverse encrypt(): base64-decode *encrypted_string* and return it as UTF-8 text."""
    import base64
    raw = base64.b64decode(bytes(encrypted_string, 'utf-8'))
    return raw.decode('utf-8')

#=========================================================================================
# encrypt a string - using cryptography.fernet encryption, pswd required
#=========================================================================================
def encrypt2(str, pswd):
    """Encrypt *str* with Fernet, using *pswd* as the raw key material.

    Args:
        str: clear text to encrypt (note: this parameter shadows the builtin).
        pswd: password; must be 32 characters (32 bytes once UTF-8 encoded).
            It is base64url-encoded to form the Fernet key, and Fernet
            rejects keys of any other size.

    Returns:
        The Fernet token as *bytes* (not str).
    """
    import base64
    from cryptography.fernet import Fernet
    # The password bytes themselves become the key; Fernet.generate_key() would
    # be the alternative when a random key is acceptable.
    cipher = Fernet(base64.urlsafe_b64encode(bytes(pswd, 'utf-8')))
    return cipher.encrypt(str.encode())

#=========================================================================================
# decrypt a string - using cryptography.fernet encryption, pswd required
#=========================================================================================
def decrypt2(str, pswd):
    """Decrypt a Fernet token produced by encrypt2() with the same *pswd*.

    Args:
        str: the Fernet token returned by encrypt2().
        pswd: the same 32-character password used to encrypt.

    Returns:
        The clear text as a UTF-8 decoded string.
    """
    import base64
    from cryptography.fernet import Fernet
    fernet_key = base64.urlsafe_b64encode(bytes(pswd, 'utf-8'))
    plain_bytes = Fernet(fernet_key).decrypt(str)
    return plain_bytes.decode()

#=========================================================================================
# encrypt/decrypt a string - key required
# simple character-rotation cipher driven by a user provided key (obfuscation, NOT strong)
#=========================================================================================
def crypt(string, key, direction='encrypt'):
    """Rotate the characters of *string* through a fixed 80-character alphabet.

    The rotation distance is the sum of the code points of *key* modulo the
    alphabet length; a distance of 0 is bumped to 1. This is a Caesar-style
    substitution cipher: it obfuscates text but is NOT strong encryption.
    Characters outside the alphabet (space, '-', '_', '"', newlines,
    non-ASCII, ...) are returned unchanged.

    Args:
        string: text to transform.
        key: any string; only the sum of its code points matters.
        direction: 'encrypt' (case-insensitive) rotates forward; ANY other
            value rotates backward, i.e. decrypts.

    Returns:
        The transformed string.
    """
    alphabet = ("ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                "abcdefghijklmnopqrstuvwxyz"
                "0123456789"
                "~!@#$%^&*:;.',?+=/")
    shift = sum(ord(ch) for ch in key) % len(alphabet) or 1
    rotated = alphabet[shift:] + alphabet[:shift]
    encrypting = direction.lower() == 'encrypt'
    from_chars, to_chars = (alphabet, rotated) if encrypting else (rotated, alphabet)
    return string.translate(string.maketrans(from_chars, to_chars))

#=========================================================================================
# encrypt a string - key required
# thin wrapper around crypt() with direction 'encrypt'
#=========================================================================================
def encrypt3(string, key):
    """Encrypt *string* with the crypt() rotation cipher keyed by *key*."""
    return crypt(string, key, direction='encrypt')

#=========================================================================================
# decrypt a string - key required
# thin wrapper around crypt() with direction 'decrypt'
#=========================================================================================
def decrypt3(string, key):
    """Decrypt *string* (produced by encrypt3 with the same *key*) via crypt()."""
    return crypt(string, key, direction='decrypt')
    

#=========================================================================================
# Add connection_string to mysql or oracle keys file
# Receives:  connection name (also used as key), connection string, db=mysql|oracle,
#            optional path overriding the default keys file
# Writes:    connection_name=encrypted_connection_string
#=========================================================================================
def add_connection_string(name, connection_string, db='mysql', path=None):
    """Encrypt *connection_string* with encrypt3 (keyed by *name*) and append it to a keys file.

    Args:
        name: connection name; also used as the encrypt3 key. Must not contain
            '=' or a line break, otherwise the record could never be read back.
        connection_string: clear-text connection string; must not contain a
            line break.
        db: 'mysql' selects mysql_keys.txt; any other value selects oracle_keys.txt.
        path: optional explicit keys-file path; when given it overrides the
            *db*-based default.

    Returns:
        The record that was appended, ``name=encrypted_string\\n``.

    Raises:
        ValueError: if *name* contains '=' or a line break, or the encrypted
            value contains a line break.
    """
    if path is None:
        if db == 'mysql':
            path = '/home/sultans/web/python/demo/etl/util/mysql_keys.txt'
        else:
            path = '/home/sultans/web/python/demo/etl/util/oracle_keys.txt'
    # Records are one per line and split on the first '=', so these characters
    # would produce a record that can never be looked up again.
    if '=' in name or '\n' in name or '\r' in name:
        raise ValueError("connection name must not contain '=' or a line break")
    encrypted_string = encrypt3(connection_string, name)
    if '\n' in encrypted_string or '\r' in encrypted_string:
        raise ValueError('connection string must not contain a line break')
    new_connection = name + '=' + encrypted_string + '\n'
    # 'with' guarantees the file is closed even if the write fails
    with open(path, 'a') as output:
        output.write(new_connection)
    return new_connection

#=========================================================================================
# Retrieve a connection string by name from mysql or oracle keys file
# Receives:  connection name (also used as key), db=mysql|oracle,
#            optional path overriding the default keys file
# Returns:   connection string decrypted, ready to be used in a database connection
#=========================================================================================
def get_connection_string(name, db='mysql', path=None):
    """Look up *name* in a keys file and return its decrypt3-decrypted connection string.

    Each line of the file is ``name=encrypted_string``. The first line whose name
    matches wins. Blank or malformed lines (no '=') are skipped, including the
    empty "line" after the trailing newline that add_connection_string writes.

    Args:
        name: connection name to find; also used as the decrypt3 key.
        db: 'mysql' selects mysql_keys.txt; any other value selects oracle_keys.txt.
        path: optional explicit keys-file path; when given it overrides the
            *db*-based default.

    Returns:
        The decrypted connection string, or ``"<name> Not found"`` when no line
        matches.
    """
    if path is None:
        if db == 'mysql':
            path = '/home/sultans/web/python/demo/etl/util/mysql_keys.txt'
        else:
            path = '/home/sultans/web/python/demo/etl/util/oracle_keys.txt'
    with open(path, 'r') as keys_file:
        for line in keys_file:
            connection, sep, encrypted_string = line.rstrip('\n').partition('=')
            if not sep:
                # Blank/malformed line: previously this crashed with ValueError on unpacking.
                continue
            if connection == name:
                return decrypt3(encrypted_string, name)
    # Report the name that was requested (not whichever line happened to be last).
    return name + ' Not found'

#=========================================================================================