package com.union.crypto.engines;

import com.union.crypto.BlockCipher;
import com.union.crypto.CipherParameters;
import com.union.crypto.DataLengthException;
import com.union.crypto.OutputLengthException;
import com.union.crypto.params.KeyParameter;

/**
 * SM4 block cipher (Chinese national commercial cryptography standard,
 * GM/T 0002-2012 / GB/T 32907-2016).
 * <p>
 * 128-bit block, 128-bit key, 32 rounds. The round-key schedule is computed
 * once in {@link #init(boolean, CipherParameters)}. For decryption the
 * schedule is stored in reverse order, so {@link #sm4Func} runs the same
 * round loop in both directions.
 * <p>
 * NOTE: the S-box is applied via table lookups indexed by key- and
 * data-dependent bytes, so this implementation is not constant-time with
 * respect to cache-timing side channels.
 * <p>
 * No synchronization is performed; {@link #init} replaces the working key
 * of this instance.
 *
 * @author longwx
 * @date 2015-11-02
 */
public class SM4Engine implements BlockCipher {

	/** Block size in bytes (128 bits). */
	protected static final int BLOCK_SIZE = 16;

	/** Key size in bytes (128 bits). */
	private static final int KEY_SIZE = 16;

	/** Number of rounds, which is also the number of round keys. */
	private static final int ROUNDS = 32;

	/** Round keys in the order they are applied; {@code null} until initialised. */
	private int[] workingKey = null;

	/**
	 * Initialises the engine with a 128-bit key.
	 *
	 * @param encrypting {@code true} to set up for encryption, {@code false} for decryption
	 * @param params     must be a {@link KeyParameter} holding exactly 16 key bytes
	 * @throws IllegalArgumentException if {@code params} is not a {@link KeyParameter}
	 *                                  or the key is not exactly 16 bytes long
	 */
	@Override
	public void init(boolean encrypting, CipherParameters params)
			throws IllegalArgumentException {
		if (!(params instanceof KeyParameter)) {
			// Guard against null so callers get the documented exception, not an NPE.
			throw new IllegalArgumentException(
					"invalid parameter passed to SM4 init - "
							+ (params == null ? "null" : params.getClass().getName()));
		}

		byte[] key = ((KeyParameter) params).getKey();
		// Reject short keys as well as long ones: a short key used to pass this
		// check and then fail inside the key schedule with an index exception.
		if (key == null || key.length != KEY_SIZE) {
			throw new IllegalArgumentException(
					"SM4 key must be 16 bytes, got "
							+ (key == null ? "null" : key.length + " bytes"));
		}

		workingKey = generateWorkingKey(encrypting, key);
	}

	@Override
	public String getAlgorithmName() {
		return "SM4";
	}

	@Override
	public int getBlockSize() {
		return BLOCK_SIZE;
	}

	/**
	 * Encrypts or decrypts one 16-byte block, depending on how the engine was initialised.
	 * The input is fully read before any output is written, so {@code in} and {@code out}
	 * may be the same array at the same offset.
	 *
	 * @return the number of bytes processed, always {@link #BLOCK_SIZE}
	 * @throws IllegalStateException if {@link #init} has not been called
	 * @throws DataLengthException   if fewer than 16 input bytes are available at {@code inOff}
	 * @throws OutputLengthException if fewer than 16 output bytes are available at {@code outOff}
	 */
	@Override
	public int processBlock(byte[] in, int inOff, byte[] out, int outOff)
			throws DataLengthException, IllegalStateException {
		if (workingKey == null) {
			throw new IllegalStateException("SM4 engine not initialised");
		}

		if ((inOff + BLOCK_SIZE) > in.length) {
			throw new DataLengthException("input buffer too short");
		}

		if ((outOff + BLOCK_SIZE) > out.length) {
			throw new OutputLengthException("output buffer too short");
		}

		sm4Func(workingKey, in, inOff, out, outOff);

		return BLOCK_SIZE;
	}

	/**
	 * No-op: the engine keeps no state between blocks other than the round keys,
	 * which remain valid until the next {@link #init}.
	 */
	@Override
	public void reset() {
	}

	/*----------------------------------------------------------------------*/

	/** The SM4 S-box. */
	private static final byte[] SBOX = {
		(byte) 0xD6, (byte) 0x90, (byte) 0xE9, (byte) 0xFE, (byte) 0xCC, (byte) 0xE1, (byte) 0x3D, (byte) 0xB7, (byte) 0x16, (byte) 0xB6, (byte) 0x14, (byte) 0xC2, (byte) 0x28, (byte) 0xFB, (byte) 0x2C, (byte) 0x05,
		(byte) 0x2B, (byte) 0x67, (byte) 0x9A, (byte) 0x76, (byte) 0x2A, (byte) 0xBE, (byte) 0x04, (byte) 0xC3, (byte) 0xAA, (byte) 0x44, (byte) 0x13, (byte) 0x26, (byte) 0x49, (byte) 0x86, (byte) 0x06, (byte) 0x99,
		(byte) 0x9C, (byte) 0x42, (byte) 0x50, (byte) 0xF4, (byte) 0x91, (byte) 0xEF, (byte) 0x98, (byte) 0x7A, (byte) 0x33, (byte) 0x54, (byte) 0x0B, (byte) 0x43, (byte) 0xED, (byte) 0xCF, (byte) 0xAC, (byte) 0x62,
		(byte) 0xE4, (byte) 0xB3, (byte) 0x1C, (byte) 0xA9, (byte) 0xC9, (byte) 0x08, (byte) 0xE8, (byte) 0x95, (byte) 0x80, (byte) 0xDF, (byte) 0x94, (byte) 0xFA, (byte) 0x75, (byte) 0x8F, (byte) 0x3F, (byte) 0xA6,
		(byte) 0x47, (byte) 0x07, (byte) 0xA7, (byte) 0xFC, (byte) 0xF3, (byte) 0x73, (byte) 0x17, (byte) 0xBA, (byte) 0x83, (byte) 0x59, (byte) 0x3C, (byte) 0x19, (byte) 0xE6, (byte) 0x85, (byte) 0x4F, (byte) 0xA8,
		(byte) 0x68, (byte) 0x6B, (byte) 0x81, (byte) 0xB2, (byte) 0x71, (byte) 0x64, (byte) 0xDA, (byte) 0x8B, (byte) 0xF8, (byte) 0xEB, (byte) 0x0F, (byte) 0x4B, (byte) 0x70, (byte) 0x56, (byte) 0x9D, (byte) 0x35,
		(byte) 0x1E, (byte) 0x24, (byte) 0x0E, (byte) 0x5E, (byte) 0x63, (byte) 0x58, (byte) 0xD1, (byte) 0xA2, (byte) 0x25, (byte) 0x22, (byte) 0x7C, (byte) 0x3B, (byte) 0x01, (byte) 0x21, (byte) 0x78, (byte) 0x87,
		(byte) 0xD4, (byte) 0x00, (byte) 0x46, (byte) 0x57, (byte) 0x9F, (byte) 0xD3, (byte) 0x27, (byte) 0x52, (byte) 0x4C, (byte) 0x36, (byte) 0x02, (byte) 0xE7, (byte) 0xA0, (byte) 0xC4, (byte) 0xC8, (byte) 0x9E,
		(byte) 0xEA, (byte) 0xBF, (byte) 0x8A, (byte) 0xD2, (byte) 0x40, (byte) 0xC7, (byte) 0x38, (byte) 0xB5, (byte) 0xA3, (byte) 0xF7, (byte) 0xF2, (byte) 0xCE, (byte) 0xF9, (byte) 0x61, (byte) 0x15, (byte) 0xA1,
		(byte) 0xE0, (byte) 0xAE, (byte) 0x5D, (byte) 0xA4, (byte) 0x9B, (byte) 0x34, (byte) 0x1A, (byte) 0x55, (byte) 0xAD, (byte) 0x93, (byte) 0x32, (byte) 0x30, (byte) 0xF5, (byte) 0x8C, (byte) 0xB1, (byte) 0xE3,
		(byte) 0x1D, (byte) 0xF6, (byte) 0xE2, (byte) 0x2E, (byte) 0x82, (byte) 0x66, (byte) 0xCA, (byte) 0x60, (byte) 0xC0, (byte) 0x29, (byte) 0x23, (byte) 0xAB, (byte) 0x0D, (byte) 0x53, (byte) 0x4E, (byte) 0x6F,
		(byte) 0xD5, (byte) 0xDB, (byte) 0x37, (byte) 0x45, (byte) 0xDE, (byte) 0xFD, (byte) 0x8E, (byte) 0x2F, (byte) 0x03, (byte) 0xFF, (byte) 0x6A, (byte) 0x72, (byte) 0x6D, (byte) 0x6C, (byte) 0x5B, (byte) 0x51,
		(byte) 0x8D, (byte) 0x1B, (byte) 0xAF, (byte) 0x92, (byte) 0xBB, (byte) 0xDD, (byte) 0xBC, (byte) 0x7F, (byte) 0x11, (byte) 0xD9, (byte) 0x5C, (byte) 0x41, (byte) 0x1F, (byte) 0x10, (byte) 0x5A, (byte) 0xD8,
		(byte) 0x0A, (byte) 0xC1, (byte) 0x31, (byte) 0x88, (byte) 0xA5, (byte) 0xCD, (byte) 0x7B, (byte) 0xBD, (byte) 0x2D, (byte) 0x74, (byte) 0xD0, (byte) 0x12, (byte) 0xB8, (byte) 0xE5, (byte) 0xB4, (byte) 0xB0,
		(byte) 0x89, (byte) 0x69, (byte) 0x97, (byte) 0x4A, (byte) 0x0C, (byte) 0x96, (byte) 0x77, (byte) 0x7E, (byte) 0x65, (byte) 0xB9, (byte) 0xF1, (byte) 0x09, (byte) 0xC5, (byte) 0x6E, (byte) 0xC6, (byte) 0x84,
		(byte) 0x18, (byte) 0xF0, (byte) 0x7D, (byte) 0xEC, (byte) 0x3A, (byte) 0xDC, (byte) 0x4D, (byte) 0x20, (byte) 0x79, (byte) 0xEE, (byte) 0x5F, (byte) 0x3E, (byte) 0xD7, (byte) 0xCB, (byte) 0x39, (byte) 0x48
	};

	/** Fixed key-schedule constants CK[0..31]. */
	private static final int[] CK = {
			0x00070E15, 0x1C232A31, 0x383F464D, 0x545B6269,
			0x70777E85, 0x8C939AA1, 0xA8AFB6BD, 0xC4CBD2D9,
			0xE0E7EEF5, 0xFC030A11, 0x181F262D, 0x343B4249,
			0x50575E65, 0x6C737A81, 0x888F969D, 0xA4ABB2B9,
			0xC0C7CED5, 0xDCE3EAF1, 0xF8FF060D, 0x141B2229,
			0x30373E45, 0x4C535A61, 0x686F767D, 0x848B9299,
			0xA0A7AEB5, 0xBCC3CAD1, 0xD8DFE6ED, 0xF4FB0209,
			0x10171E25, 0x2C333A41, 0x484F565D, 0x646B7279
	};

	/** System parameters FK[0..3], XORed into the master key before the schedule. */
	private static final int[] FK = {
			0xA3B1BAC6, 0x56AA3350, 0x677D9197, 0xB27022DC
	};

	/** Writes {@code in} big-endian into {@code out[offsetOut..offsetOut+3]}. */
	private static void intToBytes(int in, byte[] out, int offsetOut) {
		out[offsetOut] = (byte) (in >>> 24);
		out[offsetOut + 1] = (byte) (in >>> 16);
		out[offsetOut + 2] = (byte) (in >>> 8);
		out[offsetOut + 3] = (byte) in;
	}

	/** Reads a big-endian 32-bit word from {@code in[offset..offset+3]}. */
	private static int bytesToInt(byte[] in, int offset) {
		return ((in[offset] & 0xFF) << 24)
				| ((in[offset + 1] & 0xFF) << 16)
				| ((in[offset + 2] & 0xFF) << 8)
				| (in[offset + 3] & 0xFF);
	}

	/**
	 * Non-linear transform tau: applies the S-box to each of the four bytes of
	 * {@code in} independently.
	 */
	private static int tau(int in) {
		return ((SBOX[(in >>> 24) & 0xFF] & 0xFF) << 24)
				| ((SBOX[(in >>> 16) & 0xFF] & 0xFF) << 16)
				| ((SBOX[(in >>> 8) & 0xFF] & 0xFF) << 8)
				| (SBOX[in & 0xFF] & 0xFF);
	}

	/**
	 * Encryption/decryption composite transform T(x) = L(tau(x)), where
	 * L(B) = B ^ (B <<< 2) ^ (B <<< 10) ^ (B <<< 18) ^ (B <<< 24).
	 */
	private static int roundT(int in) {
		int b = tau(in);
		return b ^ Integer.rotateLeft(b, 2) ^ Integer.rotateLeft(b, 10)
				^ Integer.rotateLeft(b, 18) ^ Integer.rotateLeft(b, 24);
	}

	/**
	 * Key-schedule composite transform T'(x) = L'(tau(x)), where
	 * L'(B) = B ^ (B <<< 13) ^ (B <<< 23).
	 */
	private static int keyT(int in) {
		int b = tau(in);
		return b ^ Integer.rotateLeft(b, 13) ^ Integer.rotateLeft(b, 23);
	}

	/**
	 * Expands a 16-byte key into the 32 round keys.
	 * <p>
	 * rk[i] = K[i+4] = K[i] ^ T'(K[i+1] ^ K[i+2] ^ K[i+3] ^ CK[i]), with
	 * K[0..3] = MK[0..3] ^ FK[0..3].
	 *
	 * @param encrypting {@code true} to return rk[0..31] in order; {@code false} to
	 *                   return them reversed (rk[31] first), as decryption requires
	 * @param key        the key; at least 16 bytes are read (callers must check the length)
	 * @return a new array of 32 round keys in application order
	 */
	protected int[] generateWorkingKey(boolean encrypting, byte[] key) {
		int k0 = bytesToInt(key, 0) ^ FK[0];
		int k1 = bytesToInt(key, 4) ^ FK[1];
		int k2 = bytesToInt(key, 8) ^ FK[2];
		int k3 = bytesToInt(key, 12) ^ FK[3];

		int[] rk = new int[ROUNDS];
		// Four schedule steps per iteration, rotating the roles of k0..k3 so the
		// sliding window (K[i], K[i+1], K[i+2], K[i+3]) needs no shuffling.
		for (int i = 0; i < ROUNDS; i += 4) {
			k0 ^= keyT(k1 ^ k2 ^ k3 ^ CK[i]);
			rk[i] = k0;
			k1 ^= keyT(k2 ^ k3 ^ k0 ^ CK[i + 1]);
			rk[i + 1] = k1;
			k2 ^= keyT(k3 ^ k0 ^ k1 ^ CK[i + 2]);
			rk[i + 2] = k2;
			k3 ^= keyT(k0 ^ k1 ^ k2 ^ CK[i + 3]);
			rk[i + 3] = k3;
		}

		// Decryption is the same round loop driven by the round keys in reverse order.
		if (!encrypting) {
			for (int lo = 0, hi = ROUNDS - 1; lo < hi; lo++, hi--) {
				int tmp = rk[lo];
				rk[lo] = rk[hi];
				rk[hi] = tmp;
			}
		}
		return rk;
	}

	/**
	 * Runs the 32 SM4 rounds over one block using the supplied round keys.
	 * <p>
	 * X[i+4] = F(X[i], X[i+1], X[i+2], X[i+3], rk[i]) = X[i] ^ T(X[i+1] ^ X[i+2] ^ X[i+3] ^ rk[i]).
	 * The output is the reverse transform R: (X[35], X[34], X[33], X[32]).
	 *
	 * @param wKey   32 round keys in application order
	 * @param in     input buffer; 16 bytes are read starting at {@code inOff}
	 * @param inOff  input offset
	 * @param out    output buffer; 16 bytes are written starting at {@code outOff}
	 * @param outOff output offset
	 */
	protected void sm4Func(int[] wKey, byte[] in, int inOff, byte[] out, int outOff) {
		// Read the whole block first so in-place operation (in == out) is safe.
		int x0 = bytesToInt(in, inOff);
		int x1 = bytesToInt(in, inOff + 4);
		int x2 = bytesToInt(in, inOff + 8);
		int x3 = bytesToInt(in, inOff + 12);

		// Four rounds per iteration with rotating roles, mirroring the key schedule.
		for (int i = 0; i < ROUNDS; i += 4) {
			x0 ^= roundT(x1 ^ x2 ^ x3 ^ wKey[i]);
			x1 ^= roundT(x2 ^ x3 ^ x0 ^ wKey[i + 1]);
			x2 ^= roundT(x3 ^ x0 ^ x1 ^ wKey[i + 2]);
			x3 ^= roundT(x0 ^ x1 ^ x2 ^ wKey[i + 3]);
		}

		// Reverse transform R: emit the last four words in reverse order.
		intToBytes(x3, out, outOff);
		intToBytes(x2, out, outOff + 4);
		intToBytes(x1, out, outOff + 8);
		intToBytes(x0, out, outOff + 12);
	}

}
