#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <mycrypt.h>
/* Shared-secret challenge-response authentication over a connected socket
   (HMAC-SHA1 keyed with the secret), plus a small client/server driver. */

/* Both return 1 when authentication succeeded; callers should treat any
   other value as failure.  secret/inlen is the shared key and must be
   byte-for-byte identical on both sides, including any zero padding. */
int auth_server(int socketfd, unsigned char *secret, unsigned long inlen);
int auth_client(int socketfd, unsigned char *secret, unsigned long inlen);

/* Register the libtomcrypt algorithms this program uses.
   auth_client()/auth_server() look up "sha1" with find_hash(), which only
   succeeds after sha1_desc has been registered here.  Per the libtomcrypt
   manual the register_*() functions return -1 on failure; report it instead
   of ignoring it, so a later lookup/HMAC error has a visible cause.
   Registration order is unchanged (it determines the table indices). */
void register_algs(void) {
	if (register_cipher(&blowfish_desc) == -1) {
		printf("register_cipher(blowfish) failed\n");
	}
	if (register_prng(&sprng_desc) == -1) {
		printf("register_prng(sprng) failed\n");
	}
	if (register_hash(&sha1_desc) == -1) {
		printf("register_hash(sha1) failed\n");
	}
	if (register_prng(&yarrow_desc) == -1) {
		printf("register_prng(yarrow) failed\n");
	}
}

/* Defined later in this file; declared here so main() does not rely on
   implicit function declarations (removed in C99). */
int client(int port, char *host);
int server(int port);

/* Entry point.
     prog <host> <port>   run the client against host:port
     prog <port>          run the server on port
   Ports are parsed with atoi(); 0 (including non-numeric input) makes
   client()/server() fall back to their default port 8974.
   Returns 0 when the selected mode finishes, 1 on a usage error.
   client()/server() exit the process themselves on failure. */
int main(int argc, char *argv[]) {
	if (argc == 3) {
		client(atoi(argv[2]), argv[1]);
	} else if (argc == 2) {
		server(atoi(argv[1]));
	} else {
		/* argv[0] may be NULL when argc == 0; never pass NULL to %s. */
		printf("%s [host] <port>\n", argc > 0 && argv[0] ? argv[0] : "auth");
		return 1;
	}
	return 0;
}
	
/* Client side of the demo: connect to host:port and answer the server's
   challenge with auth_client().

   port - TCP port; 0 selects the default 8974
   host - dotted-quad IPv4 address of the server (hostnames are not resolved)

   Returns 0 once auth_client() reports success.  Note that the server never
   sends its verdict back, so "OK" here only means the response was sent.
   On any error a message is printed and the process exits. */
int client(int port,char *host){
	int socketfd;
	unsigned char secret[MAXBLOCKSIZE];
	struct sockaddr_in clientaddr;

	register_algs();
	/* The whole buffer (sizeof(secret)) is used as the HMAC key, so the zero
	   padding is part of the shared secret and must match server(). */
	memset(secret, 0, sizeof(secret));
	/* NOTE(review): hard-coded demo secret (looks like a hex SHA-1 digest);
	   load it from configuration for anything beyond a demo. */
	strcpy((char *)secret, "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8");

	if (!port) {
		port = 8974;
	}
	/* htons() would silently truncate anything outside 0..65535. */
	if (port < 0 || port > 65535) {
		printf("invalid port\n");
		exit(-1);
	}

	/* Build the address from the caller's host instead of a fixed loopback
	   address; validate it before creating the socket so nothing leaks. */
	memset(&clientaddr, 0, sizeof(clientaddr));
	clientaddr.sin_family = AF_INET;
	clientaddr.sin_port = htons((unsigned short)port);
	if (host == NULL || inet_pton(AF_INET, host, &clientaddr.sin_addr) != 1) {
		printf("invalid host address\n");
		exit(-1);
	}

	if ((socketfd = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
		printf("socket error\n");
		exit(1);
	}
	printf("connecting\n");
	if (connect(socketfd, (struct sockaddr *)&clientaddr, sizeof(clientaddr)) == -1) {
		printf("couldn't connect\n");
		close(socketfd);
		exit(-1);
	}
	if (auth_client(socketfd, secret, sizeof(secret)) == 1) {
		printf("OK\n");
	} else {
		printf("failed\n");
		close(socketfd);
		exit(-1);
	}
	close(socketfd);
	return 0;
}
/* Server side of the demo: listen on all interfaces, accept a single
   connection and run auth_server() on it.

   port - TCP port; 0 selects the default 8974

   Returns 0 after a successful authentication.  On any error a message is
   printed and the process exits (exit(1) if socket() fails, exit(-1)
   otherwise). */
int server(int port){
	int listenfd, socketfd;
	unsigned char secret[MAXBLOCKSIZE];
	struct sockaddr_in serveraddr;

	register_algs();
	/* The whole buffer (sizeof(secret)) is used as the HMAC key, so the zero
	   padding is part of the shared secret and must match client(). */
	memset(secret, 0, sizeof(secret));
	/* NOTE(review): hard-coded demo secret (looks like a hex SHA-1 digest);
	   load it from configuration for anything beyond a demo. */
	strcpy((char *)secret, "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8");

	if (!port) {
		port = 8974;
	}
	/* htons() would silently truncate anything outside 0..65535. */
	if (port < 0 || port > 65535) {
		printf("invalid port\n");
		exit(-1);
	}

	if ((listenfd = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
		printf("socket error\n");
		exit(1);
	}
	memset(&serveraddr, 0, sizeof(serveraddr));
	serveraddr.sin_family = AF_INET;
	serveraddr.sin_port = htons((unsigned short)port);
	serveraddr.sin_addr.s_addr = htonl(INADDR_ANY);
	if (bind(listenfd, (struct sockaddr *)&serveraddr, sizeof(serveraddr)) == -1) {
		printf("bind error\n");
		close(listenfd);
		exit(-1);
	}

	/* The second argument is the backlog (pending-connection queue length),
	   not the port; only one client is served, so a queue of 1 suffices. */
	if (listen(listenfd, 1) == -1) {
		printf("listen error\n");
		close(listenfd);
		exit(-1);
	}
	printf("Listening for connection\n");
	if ((socketfd = accept(listenfd, NULL, NULL)) == -1) {
		printf("accept error\n");
		close(listenfd);
		exit(-1);
	}
	/* Single-shot server: stop accepting once we have our peer. */
	close(listenfd);

	if (auth_server(socketfd, secret, sizeof(secret)) == 1) {
		printf("OK\n");
	} else {
		printf("failed\n");
		close(socketfd);
		exit(-1);
	}
	close(socketfd);
	return 0;
}
/* Don't forget to hash the password/etc before calling this.
   You want the same secret(hash) as the server */
   
int auth_client(int socketfd, unsigned char *secret, unsigned long inlen) {
	unsigned char mysecret[inlen];
	unsigned char nonce[128];
	unsigned char recv_buf[MAXBLOCKSIZE];
	unsigned char challenge[MAXBLOCKSIZE];
	unsigned char response[MAXBLOCKSIZE];
	unsigned char noncesec[MAXBLOCKSIZE+128];
	unsigned long keylen;
	int x,y,z,errno;
	prng_state prng;
	hash_state hash;
	hmac_state hmac;

	/* I hate you, I hate you, I hate you */
	memset(response,0,MAXBLOCKSIZE);
	memset(mysecret,0,inlen);
	memset(nonce,0,128);
	memset(recv_buf,0,MAXBLOCKSIZE);
	memset(challenge,0,MAXBLOCKSIZE);
	
	if (socketfd == -1) {
		printf("socket error\n");
		exit(-1);
	}
	if (inlen > MAXBLOCKSIZE) {
		printf("exceeded maxblocksize\n");
		exit(-1);
	}
	/* read in the nonce into recv_buf */
	if (read(socketfd,recv_buf,sizeof(recv_buf)) == -1) {
		printf("read error\n");
		exit(-1);
	}
	printmyhex(recv_buf,sizeof(recv_buf));
	/* use the secret as the key in the hmac stuff, then run the nonce
	   thought it. Yes. We get the bit we send to the server in hmac_done */
	if ((errno = hmac_init(&hmac,find_hash("sha1"),secret,inlen)) != CRYPT_OK) {
		printf("hmacinit error: %s\n",error_to_string(errno));
	}
	if ((errno = hmac_process(&hmac,recv_buf,sizeof(recv_buf))) != CRYPT_OK) {
		printf("hmacproc error: %s\n",error_to_string(errno));
	}
	hmac_done(&hmac,response);
	if (send(socketfd,response,sizeof(response),NULL) == -1) {
		printf("send error\n");
		exit(-1);
	}
	return(1);
}

int auth_server(int socketfd, unsigned char *secret,unsigned long inlen) {
	unsigned char mysecret[inlen];
	unsigned char nonce[128];
	unsigned char recv_buf[MAXBLOCKSIZE];
	unsigned char challenge[MAXBLOCKSIZE];
	int x,y,z,errno;
	prng_state prng;
	hash_state hash;
	hmac_state hmac;
	/* I hate you, I hate you, I hate you */
	memset(mysecret,0,inlen);
	memset(nonce,0,128);
	memset(recv_buf,0,MAXBLOCKSIZE);
	memset(challenge,0,MAXBLOCKSIZE);
	
	
	if (socketfd == -1) {
		printf("socket error\n");
		exit(-1);
	}
	if (inlen > MAXBLOCKSIZE) {
		printf("exceeded maxblocksize\n");
		exit(-1);
	}
	strncpy(mysecret,secret,inlen);
	/* fix later */
	x = rng_get_bytes(nonce,sizeof(nonce),NULL);
	
	if (send(socketfd,nonce,sizeof(nonce),NULL) == -1) {
		printf("send error\n");
		exit(-1);
	}
	/* don't forget to flush/clearn the recv buffer incase they client
	   sends bad stuff */
	if (read(socketfd,recv_buf,sizeof(recv_buf)) == -1) {
		printf("read error\n");
		exit(-1);
	}
	/* Use the shared secret and hmac the nonce, then compared what the client
	   sent us. If it matches they know the secret */
	if ((errno = hmac_init(&hmac,find_hash("sha1"),mysecret,sizeof(mysecret))) != CRYPT_OK) {
		printf("hmacinit error: %s\n",error_to_string(errno));
	}
	if ((errno = hmac_process(&hmac,nonce,sizeof(nonce))) != CRYPT_OK) {
		printf("hmacproc error: %s\n",error_to_string(errno));
	}
	hmac_done(&hmac,challenge);
	printmyhex(challenge,sizeof(challenge));
	printmyhex(recv_buf,sizeof(recv_buf));
	if (!strcmp(challenge,recv_buf)) {
		printf("Authenticated OK!\n");
		return(1);
	} else {
		printf("failed\n");
		return(-1);
	}
}	
/* Debug helper: print the first asize bytes of array as hex on one line,
   followed by a newline (a non-positive asize prints just the newline).
   Every byte is printed as exactly two digits so the dump can be split back
   into bytes; with plain %x, {0x01, 0x23} and {0x12, 0x03} both print "123".
   Returns int (always 0) rather than void because it is called earlier in
   the file than this definition, where callers assume an int return. */
int printmyhex(unsigned char *array,int asize) {
	int x;
	for (x = 0; x < asize; x++) {
		printf("%02x", array[x]);
	}
	printf("\n");
	return 0;
}
