/*
    libmaus2
    Copyright (C) 2009-2014 German Tischler
    Copyright (C) 2011-2014 Genome Research Limited

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/
#include <libmaus2/digest/SHA2_256_sse4.hpp>
#include <libmaus2/digest/sha256.h>
#include <libmaus2/rank/BSwapBase.hpp>
#include <algorithm>

#if defined(LIBMAUS2_HAVE_x86_64)
#include <emmintrin.h>
#endif

/*
 * Construct an SSE4 accelerated SHA-256 hasher.
 *
 * Allocates a partial-block buffer with room for two blocks (the second one is
 * used as spill space for the padding in digest()), the working state digestw
 * and a copy of the SHA-256 initial hash value in digestinit.
 *
 * Throws a LibMausException if the assembly code was not compiled in or if the
 * running CPU does not report SSE4.1 support.
 */
libmaus2::digest::SHA2_256_sse4::SHA2_256_sse4()
: block(2*(1ull<<libmaus2::digest::SHA2_256_sse4::blockshift),false),
  digestw(base_type::digestlength / sizeof(uint32_t),false),
  digestinit(base_type::digestlength / sizeof(uint32_t),false),
  index(0), blockcnt(0)
{
	// refuse to construct when the assembly implementation is missing from the build
	#if ! ( defined(LIBMAUS2_USE_ASSEMBLY) &&  defined(LIBMAUS2_HAVE_x86_64) && defined(LIBMAUS2_HAVE_i386) && defined(LIBMAUS2_HAVE_SHA2_ASSEMBLY) )
	libmaus2::exception::LibMausException lme;
	lme.getStream() << "SHA2_256_sse4(): code has not been compiled into libmaus" << std::endl;
	lme.finish();
	throw lme;
	#endif

	// runtime CPU capability check; without the i386 feature probe we assume no support
	#if defined(LIBMAUS2_USE_ASSEMBLY) && defined(LIBMAUS2_HAVE_i386)
	bool const sse41available = libmaus2::util::I386CacheLineSize::hasSSE41();
	#else
	bool const sse41available = false;
	#endif

	if ( ! sse41available )
	{
		libmaus2::exception::LibMausException lme;
		lme.getStream() << "SHA2_256_sse4(): machine does not support SSE4" << std::endl;
		lme.finish();
		throw lme;
	}

	// SHA-256 initial hash value H(0)
	static uint32_t const sha256_initial_state[8] =
	{
		0x6a09e667ul, 0xbb67ae85ul, 0x3c6ef372ul, 0xa54ff53aul,
		0x510e527ful, 0x9b05688cul, 0x1f83d9abul, 0x5be0cd19ul
	};

	std::copy(sha256_initial_state, sha256_initial_state + 8, &digestinit[0]);
}
/*
 * Destructor. No explicit cleanup is performed here; all state lives in member
 * objects, which are destroyed implicitly.
 */
libmaus2::digest::SHA2_256_sse4::~SHA2_256_sse4() {}

/*
 * Reset the hasher to start a new message: clear the buffered-byte and
 * processed-block counters and reload the working state from digestinit.
 * Both digestw and digestinit are accessed with aligned SSE2 loads/stores.
 */
void libmaus2::digest::SHA2_256_sse4::init()
{
	#if defined(LIBMAUS2_HAVE_x86_64)
	blockcnt = 0;
	index = 0;

	// the 256 bit state is moved as two 128 bit registers
	__m128i const * src = reinterpret_cast<__m128i const *>(&digestinit[0]);
	__m128i       * dst = reinterpret_cast<__m128i       *>(&digestw[0]);

	for ( unsigned int r = 0; r < 2; ++r )
		_mm_store_si128(dst + r, _mm_load_si128(src + r));
	#endif
}
/*
 * Feed l bytes starting at t into the hash computation.
 *
 * Bytes are first used to complete a partially filled buffer block. Whole
 * blocks are then compressed directly from the input without copying, and any
 * remainder shorter than one block is kept in the buffer for the next call.
 */
void libmaus2::digest::SHA2_256_sse4::update(
	#if defined(LIBMAUS2_HAVE_x86_64)
	uint8_t const * t,
	size_t l
	#else
	uint8_t const *,
	size_t
	#endif
)
{
	#if defined(LIBMAUS2_HAVE_x86_64)
	uint64_t const blocksize = 1ull << base_type::blockshift;

	if ( index != 0 )
	{
		// top up the pending partial block
		uint64_t const room = blocksize - index;
		uint64_t const avail = static_cast<uint64_t>(l);
		uint64_t const take = (avail < room) ? avail : room;

		std::copy(t, t + take, &block[index]);
		t += take;
		l -= take;
		index += take;

		// input exhausted before the block became full: nothing else to do
		if ( index != blocksize )
			return;

		sha256_sse4(&block[0],&digestw[0],1);
		++blockcnt;
		index = 0;
	}

	// compress all complete blocks straight from the caller's memory
	uint64_t const nblocks = (l >> base_type::blockshift);
	uint64_t const nbytes = nblocks << base_type::blockshift;

	sha256_sse4(t,&digestw[0],nblocks);
	blockcnt += nblocks;
	t += nbytes;
	l -= nbytes;

	// keep the trailing partial block (the buffer is empty at this point)
	std::copy(t, t + l, &block[index]);
	index += l;
	#endif
}
/*
 * Finish the message and write the hash value to digest.
 *
 * Appends the SHA-256 padding to the buffered partial block: a 0x80 byte, zero
 * bytes, and the 64 bit message length in bits at the end of the last block.
 * It then compresses the final one or two blocks and serialises the state words
 * with the most significant byte of each word first.
 *
 * digest: output buffer of at least base_type::digestlength bytes. No alignment
 *         is required; the bytes are written individually.
 *
 * The streaming state (index, block, digestw) is consumed and not reset here;
 * call init() before hashing another message with the same object.
 */
void libmaus2::digest::SHA2_256_sse4::digest(
	#if defined(LIBMAUS2_HAVE_x86_64)
	uint8_t * digest
	#else
	uint8_t *
	#endif
)
{
	#if defined(LIBMAUS2_HAVE_x86_64)
	// total message length; the bit count is stored in the padding
	uint64_t const numbytes = (1ull<<base_type::blockshift) * blockcnt + index;
	uint64_t const numbits = numbytes << 3;

	// write start of padding
	block[index++] = 0x80;

	// bytes left in the current block after the 0x80 marker
	uint64_t const blockspace = (1ull<<base_type::blockshift)-index;

	if ( blockspace >= 8 )
	{
		// the 64 bit length still fits: padding completes within this block

		// not multiple of 2?
		if ( index & 1 )
		{
			block[index] = 0;
			index += 1;
		}
		// not multiple of 4?
		if ( index & 2 )
		{
			*(reinterpret_cast<uint16_t *>(&block[index])) = 0;
			index += 2;
		}
		// not multiple of 8?
		if ( index & 4 )
		{
			*(reinterpret_cast<uint32_t *>(&block[index])) = 0;
			index += 4;
		}

		// zero all 64 bit words of the block except the last one
		uint64_t * p = (reinterpret_cast<uint64_t *>(&block[index]));
		uint64_t * const pe = p + (((1ull<<base_type::blockshift)-index)/8-1);

		// use 64 bit = 8 byte words
		while ( p != pe )
			*(p++) = 0;

		// last word holds the bit length, byte swapped into big endian order
		// (SHA-256 requires big endian; the host is x86_64, i.e. little endian)
		*p = libmaus2::rank::BSwapBase::bswap8(numbits);

		sha256_sse4(&block[0],&digestw[0],1);
	}
	else
	{
		// no room for the length: padding spills into a second block
		// (block was allocated with space for two blocks by the constructor)

		// not multiple of 2?
		if ( index & 1 )
		{
			block[index] = 0;
			index += 1;
		}
		// not multiple of 4?
		if ( index & 2 )
		{
			*(reinterpret_cast<uint16_t *>(&block[index])) = 0;
			index += 2;
		}
		// not multiple of 8?
		if ( index & 4 )
		{
			*(reinterpret_cast<uint32_t *>(&block[index])) = 0;
			index += 4;
		}
		// not multiple of 16?
		if ( index & 8 )
		{
			*(reinterpret_cast<uint64_t *>(&block[index])) = 0;
			index += 8;
		}

		// rest of words in first block + all but one word in second block
		uint64_t restwords = (((1ull<<(base_type::blockshift+1))-index) >> 3) - 1;

		// erase using 128 bit words (index is 16 byte aligned here)
		__m128i * p128 = reinterpret_cast<__m128i *>(&block[index]);
		__m128i * p128e = p128 + (restwords>>1);
		__m128i z128 = _mm_setzero_si128();

		while ( p128 != p128e )
			_mm_store_si128(p128++,z128);

		// erase another 64 bit word
		uint64_t * p = (reinterpret_cast<uint64_t *>(p128)); *p++ = 0;

		// final word of the second block: big endian bit length
		*p = libmaus2::rank::BSwapBase::bswap8(numbits);

		sha256_sse4(&block[0],&digestw[0],2);
	}

	// Serialise the state words most significant byte first. Byte stores are used
	// because the caller's buffer has no alignment guarantee, so it must not be
	// accessed through a uint32_t pointer.
	uint32_t const * digesti = &digestw[0];
	uint64_t const numwords = base_type::digestlength / sizeof(uint32_t);

	for ( uint64_t i = 0; i < numwords; ++i )
	{
		uint32_t const w = digesti[i];
		digest[4*i+0] = static_cast<uint8_t>((w >> 24) & 0xFFu);
		digest[4*i+1] = static_cast<uint8_t>((w >> 16) & 0xFFu);
		digest[4*i+2] = static_cast<uint8_t>((w >>  8) & 0xFFu);
		digest[4*i+3] = static_cast<uint8_t>((w >>  0) & 0xFFu);
	}
	#endif
}
/*
 * Make this hasher an exact copy of the streaming state of O.
 *
 * Copies the first buffer block (64 bytes = 4 x 128 bit), the working state
 * digestw (32 bytes = 2 x 128 bit), the processed-block counter and the
 * buffered-byte count. All copies use aligned SSE2 loads/stores.
 */
void libmaus2::digest::SHA2_256_sse4::copyFrom(
	#if defined(LIBMAUS2_HAVE_x86_64)
	SHA2_256_sse4 const & O
	#else
	SHA2_256_sse4 const &
	#endif
)
{
	#if defined(LIBMAUS2_HAVE_x86_64)
	// pending partial block: 64 bytes = 4 SSE registers
	__m128i const * srcblock = reinterpret_cast<__m128i const *>(&O.block[0]);
	__m128i       * dstblock = reinterpret_cast<__m128i       *>(&  block[0]);

	for ( unsigned int r = 0; r < 4; ++r )
		_mm_store_si128(dstblock + r, _mm_load_si128(srcblock + r));

	// working state: 32 bytes = 2 SSE registers
	__m128i const * srcstate = reinterpret_cast<__m128i const *>(&O.digestw[0]);
	__m128i       * dststate = reinterpret_cast<__m128i       *>(&  digestw[0]);

	for ( unsigned int r = 0; r < 2; ++r )
		_mm_store_si128(dststate + r, _mm_load_si128(srcstate + r));

	index = O.index;
	blockcnt = O.blockcnt;
	#endif
}

// Forwards to init().
void libmaus2::digest::SHA2_256_sse4::vinit()
{
	init();
}
// Forwards the l bytes at u to update().
void libmaus2::digest::SHA2_256_sse4::vupdate(uint8_t const * u, size_t l)
{
	update(u,l);
}
