#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <errno.h>
#include <limits.h> // INT_MAX overflow guards

#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>

#include <openssl/crypto.h> // OPENSSL_cleanse
#include <openssl/evp.h> // encryption
#include <openssl/aes.h> // for AES_BLOCK_SIZE
#include <tpl.h> // serialization




/* Direction selector passed as the `op` argument of aes_init(). */
enum {
	ENCRYPT = 0,
	DECRYPT = 1
};


/*
 * Derive an AES-256-CBC key and IV from a passphrase and salt
 * (EVP_BytesToKey, SHA-256, 5 rounds) and set ctx up for encryption
 * (op == ENCRYPT) or decryption (any other op).
 *
 * key_data/key_data_len: passphrase bytes (need not be NUL-terminated).
 * salt: 8 bytes, as EVP_BytesToKey expects.
 * Returns 0 on success, -1 if key derivation or cipher setup failed.
 */
int aes_init( unsigned char *key_data, int key_data_len,
		unsigned char *salt, EVP_CIPHER_CTX *ctx, int op ) {
	// Digest and round count define how existing slots were encrypted:
	// changing them makes every stored slot undecryptable.
	const int rounds = 5;
	unsigned char key[32], iv[32];
	int ok = EVP_BytesToKey( EVP_aes_256_cbc(), EVP_sha256(),
		salt, key_data, key_data_len, rounds, key, iv ) == 32;

	if (ok) {
		EVP_CIPHER_CTX_init(ctx);
		if (op == ENCRYPT)
			ok = EVP_EncryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, iv) == 1;
		else
			ok = EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, iv) == 1; }

	// The cipher now holds its own copy; don't leave key material on the stack
	OPENSSL_cleanse(key, sizeof(key));
	OPENSSL_cleanse(iv, sizeof(iv));
	return(ok ? 0 : -1); }


/*
 * Encrypt *len bytes of plaintext with a ctx prepared by
 * aes_init(..., ENCRYPT).  The operation is first re-initialised on the
 * cipher/key already configured in ctx (presumably so ctx can be reused).
 *
 * On success returns a malloc'd ciphertext buffer (caller frees) and sets
 * *len to the ciphertext length.  On failure (bad length, out of memory,
 * EVP error) returns NULL and leaves *len unchanged.
 */
unsigned char *aes_encrypt( EVP_CIPHER_CTX *ctx,
		unsigned char *plaintext, int *len ) {
	int c_len, f_len = 0;
	unsigned char *ciphertext;

	// Block padding (EVP default) adds at most one block to the output
	if (*len < 0 || *len > INT_MAX - AES_BLOCK_SIZE) return(NULL);
	c_len = *len + AES_BLOCK_SIZE;
	if (!(ciphertext = malloc(c_len))) return(NULL);

	if ( EVP_EncryptInit_ex(ctx, NULL, NULL, NULL, NULL) != 1 ||
			EVP_EncryptUpdate(ctx, ciphertext, &c_len, plaintext, *len) != 1 ||
			EVP_EncryptFinal_ex(ctx, ciphertext + c_len, &f_len) != 1 ) {
		free(ciphertext);
		return(NULL); }

	*len = c_len + f_len;
	return(ciphertext); }

/*
 * Decrypt *len bytes of ciphertext with a ctx prepared by
 * aes_init(..., DECRYPT), re-initialising the operation on the cipher/key
 * already configured in ctx first.
 *
 * Returns a malloc'd buffer (caller frees); NULL only when memory cannot
 * be allocated.  On success *len is set to the plaintext length.
 * If decryption fails (typically a padding error, i.e. the wrong key, or a
 * malformed/negative length) *len is set to -1 and the returned buffer is
 * zero-filled, so a caller inspecting plaintext[0] never sees garbage
 * left over from a failed decryption.
 */
unsigned char *aes_decrypt( EVP_CIPHER_CTX *ctx,
		unsigned char *ciphertext, int *len ) {
	int in_len = *len, p_len = 0, f_len = 0;
	int ok = in_len >= 0 && in_len <= INT_MAX - AES_BLOCK_SIZE;
	// EVP_DecryptUpdate requires room for in_len + one cipher block
	size_t buf_len = ok ? (size_t) in_len + AES_BLOCK_SIZE : AES_BLOCK_SIZE;
	unsigned char *plaintext = malloc(buf_len);
	if (!plaintext) return(NULL);

	ok = ok
		&& EVP_DecryptInit_ex(ctx, NULL, NULL, NULL, NULL) == 1
		&& EVP_DecryptUpdate(ctx, plaintext, &p_len, ciphertext, in_len) == 1
		// Final fails on bad padding: without this check a wrong key's
		// garbage output would be treated as real plaintext.
		&& EVP_DecryptFinal_ex(ctx, plaintext + p_len, &f_len) == 1;

	if (!ok) {
		memset(plaintext, 0, buf_len);
		*len = -1;
		return(plaintext); }

	*len = p_len + f_len;
	return(plaintext); }



/* bs: chunk size in bytes used when reading stdin.
 * key_slot_max: number of CORE+PASS slots a FILE may hold. */
int bs = 256, key_slot_max = 8;


int main(const int argc, char *const *argv) {
	int usage() {
		fprintf( stderr, ( "Usage:\n"
			"  %s ( -h | --help )\n"
			"  %s FILE <PASS >CORE\n"
			"  %s -d FILE <PASS >NULL\n"
			"  %s -a<PASS> FILE <CORE >NULL\n"
			"Multi-key cryptocontainer.\n"
			"Intended purpose is to hold single string (CORE),"
			" encrypted by different passphrases (PASS).\n"
			"Each CORE+PASS occupies the slot in FILE (up to %d).\n"
			"On unlock, all slots are tried sequentially, first"
			" successfully decrypted CORE is returned.\n"
			"When adding new PASS, any other slot decrypted"
			" by it will be dropped, so no del-ops are necessary to"
			" replace CORE.\n" ),
			argv[0], argv[0], argv[0], argv[0], key_slot_max );
		return(1); };

	if ( argc < 2 || argc > 4 ||
			!strcmp(argv[1], "-h") || !strcmp(argv[1], "--help") ) {
		usage();
		return(0); }

	// Process arguments
	int i = 1, key_del = 0;
	char *key_set = NULL;
	if (!strcmp(argv[i], "-l")) i++;
	else {
		if (!strncmp(argv[i], "-d", 2)) { key_del = 1; i++; }
		else if (!strncmp(argv[i], "-a", 2)) key_set = argv[i++] + 2; }
	if (argc != (i+1)) return(usage());

	// Read stdin
	int j, key_len = 0;
	unsigned char *key = malloc(bs);
	while(1) {
		j = fread(key + key_len, 1, bs, stdin);
		key_len += j;
		if (j < bs) break;
		if (!realloc(key, key_len + bs)) return(3); }
	// If stdin is a key, strip final newline
	if (!key_set && key[key_len-1] == '\n') {
		if (!realloc(key, --key_len)) return(3); }


	EVP_CIPHER_CTX ctx;
	unsigned char salt[8];
	int core_len = 0, crypt_len = 0;
	unsigned char *core = NULL, *crypt;


	// Encrypt core, if adding key
	//  core in this case is read from stdin as key
	//  and key is parsed as key_set
	if (key_set) {
		// Add '-' to key beginning
		if (!realloc(key, key_len+1)) return(3);
		memmove(key+1, key, key_len++);
		key[0] = '-';

		// Encrypt and add salt to crypt beginning
		FILE *entropy = fopen("/dev/urandom", "r");
		if ( !fread(salt, 1, 8, entropy) ||
			aes_init(key_set, strlen(key_set), salt,
				&ctx, ENCRYPT) ) return(3);
		fclose(entropy);
		crypt = aes_encrypt(&ctx, key, &key_len);
		crypt_len = key_len;
		if (!realloc(crypt, crypt_len+8)) return(3);
		memmove(crypt+8, crypt, crypt_len);
		memcpy(crypt, salt, 8);

		// Set key to a correct value
		free(key); // drop plaintext core
		key = key_set;
		key_len = strlen(key); }


	tpl_bin blob;
	tpl_node *src = NULL;
	tpl_node *dst = tpl_map("A(B)", &blob);


	// Read tpl, if there is a FILE
	struct stat stat_void;
	if (!stat(argv[i], &stat_void)) {
		src = tpl_map("A(B)", &blob);
		tpl_load(src, TPL_FILE, argv[i]); }


	if (src) {
		// Unpack slots, dropping some, if requested
		tpl_bin slot[key_slot_max]; j = -1;
		while (tpl_unpack(src, 1) > 0) {
			if ( !memcpy(salt, blob.addr, 8) ||
				aes_init(key, key_len, salt,
					&ctx, DECRYPT) ) return(3);
			core_len = blob.sz - 8;
			core = aes_decrypt(&ctx, blob.addr+8, &core_len);
			if (key_set) {
				if (core[0] != '-') slot[++j] = blob; } // drop same-key slots
			else if (key_del) {
					if (core[0] != '-') slot[++j] = blob;
					else key_del = 2; }
			else {
				if (core[0] == '-') break;
				else core = NULL; } }

		// Re-pack remaining slots into dst
		if (key_set || key_del)
			while (j >= 0) {
				blob = slot[j--];
				tpl_pack(dst, 1); } }


	// Exit code
	j = 0;

	if (key_set) {
		// Append new key
		blob.addr = crypt;
		blob.sz = crypt_len + 8;
		tpl_pack(dst, 1); }

	else {
		// Check if dumping new data isn't necessary
		if (key_del == 1 || (!key_del && !core)) {
			fprintf(stderr, "No slot found for a given key\n");
			j = 1;
			goto end; }
		if (!key_del) {
			if (!fwrite(core+1, 1, core_len-1, stdout)) j = 3;
			goto end; } }

	tpl_dump(dst, TPL_FILE, argv[i]);


end:
	if (src) tpl_free(src);
	if (dst) tpl_free(dst);
	if (core) free(core);

	return(j); }

