import time
import struct
from enum import Enum
from typing import List, Dict, Optional
from .register import Register, Type

class Command(Enum):
	"""Protocol command codes.

	Each value is a single byte (an ASCII character code). Outgoing packets
	send it as the first payload byte, and replies are matched back to a
	member from their first payload byte.
	"""

	ERROR = 0x45          # 'E'
	PING = 0x70           # 'p'
	WRITE = 0x77          # 'w'
	READ = 0x72           # 'r'
	READMANY = 0x52       # 'R'
	NAME = 0x6E           # 'n'
	FIRMWARE = 0x66       # 'f'
	FIRMWARE_INFO = 0x46  # 'F'

def hex(val: int) -> Optional[int]:
	"""Return the ASCII code of the uppercase hex digit for a nibble.

	NOTE: this module-level name shadows the built-in ``hex()`` inside this
	module.

	Args:
		val: value to convert; only 0..15 is meaningful.

	Returns:
		``ord('0')``..``ord('9')`` for 0..9, ``ord('A')``..``ord('F')`` for
		10..15, and None for anything outside 0..15 (including negatives,
		which previously produced a non-digit character).
	"""
	if 0 <= val <= 9:
		return ord('0') + val
	if 10 <= val <= 15:
		return ord('A') + (val - 10)
	return None

class Packet:
	"""Outgoing request frame.

	:meth:`to_bytes` produces::

		'<' esc(payload) '|' esc(len(payload), checksum) '>'

	where ``payload`` is the command byte, the 4-byte big-endian address and
	``data``, ``checksum`` is ``(255 + sum(payload)) % 256``, and ``esc``
	rewrites every ``<``, ``|``, ``>`` and ``%`` byte as ``%XX`` (two
	uppercase hex digits) so they cannot be mistaken for framing.
	"""

	# Bytes that must be escaped inside the payload and trailer.
	_ESCAPED = frozenset(b'<|>%')

	command: 'Command'
	address: int = 0  # default target address; subclasses overwrite it
	data: bytearray  # per-instance payload bytes, created in __init__

	def __init__(self, cmd: 'Command') -> None:
		self.command = cmd
		# Always a fresh buffer per instance (no shared class-level default).
		self.data = bytearray()

	def value_to_bytes(self, value, length: int = 0) -> bytes:
		"""Serialize *value* for the wire.

		- ``str``: the first *length* characters, UTF-8 encoded (multi-byte
		  characters can make the result longer than *length* bytes).
		- ``int``: exactly *length* bytes, big-endian, two's complement
		  when negative.
		- ``float``: IEEE-754 big-endian; single precision for ``length=4``,
		  double precision for ``length=8``.

		Raises:
			ValueError: unsupported type, or a float length other than 4/8.
			OverflowError: the int does not fit in *length* bytes.
		"""
		if isinstance(value, str):
			return value[:length].encode('utf-8')
		if isinstance(value, float):
			# float has no to_bytes(); pack with struct ('>f' matches the
			# decoder used for FLOAT registers).
			if length == 4:
				return struct.pack('>f', value)
			if length == 8:
				return struct.pack('>d', value)
			raise ValueError(f"Unsupported float length: {length}")
		if isinstance(value, int):
			return value.to_bytes(length, byteorder='big', signed=(value < 0))
		raise ValueError(f"Unsupported value type: {type(value)}")

	def encode(self, arr: bytes) -> bytes:
		"""Return *arr* with every framing/escape byte replaced by ``%XX``."""
		out = bytearray()
		for byte in arr:
			if byte in self._ESCAPED:
				out += b'%%%02X' % byte
			else:
				out.append(byte)
		return bytes(out)

	def add_data(self, value, length: int = 0):
		"""Append *value*, serialized by :meth:`value_to_bytes`, to the payload."""
		self.data += self.value_to_bytes(value, length)

	def to_bytes(self) -> bytes:
		"""Build the complete framed packet.

		Raises:
			ValueError: if the payload exceeds 255 bytes, since its length is
				sent as a single byte.
		"""
		payload = bytes([self.command.value]) + self.value_to_bytes(self.address, 4) + self.data
		if len(payload) > 255:
			raise ValueError(f"Packet payload too long: {len(payload)} bytes (max 255)")
		checksum = (255 + sum(payload)) % 256
		return b'<' + self.encode(payload) + b'|' + self.encode(bytes([len(payload), checksum])) + b'>'
	
class RequestRead(Packet):
	"""READMANY request for *count* consecutive registers starting at *addr*.

	*count* is sent as a single byte, so values above 255 are clamped to 255.

	Raises:
		ValueError: if *count* is not positive. (``__init__`` cannot return
			None to its caller; the previous early ``return None`` silently
			built a packet with address 0 and no count byte.)
	"""

	def __init__(self, addr: int, count: int) -> None:
		if count <= 0:
			raise ValueError(f"count must be positive, got {count}")
		super().__init__(Command.READMANY)
		self.address = addr
		self.add_data(min(count, 255), 1)

class RequestName(Packet):
	"""NAME request addressed to the register at *addr*."""

	def __init__(self, addr: int) -> None:
		super().__init__(cmd=Command.NAME)
		self.address = addr

class RequestWrite(Packet):
	"""WRITE request storing *data* into the register at *addr*.

	Payload after the address: one byte ``type.value``, then the value:

	- BYTE / WORD / DWORD / INT: big-endian integer of 1 / 2 / 4 / 4 bytes
	  (two's complement when negative);
	- FLOAT: IEEE-754 single precision, big-endian;
	- STRING: UTF-8 text followed by a NUL terminator.

	Any other type sends only the type byte
	(NOTE(review): presumably intended for Type.EMPTY — confirm).
	"""

	# Wire width in bytes for the integer register types.
	_INT_WIDTHS = {Type.BYTE: 1, Type.WORD: 2, Type.DWORD: 4, Type.INT: 4}

	def __init__(self, addr: int, type: Type, data) -> None:
		super().__init__(Command.WRITE)
		self.type = type
		self.address = addr
		self.add_data(type.value, 1)
		width = self._INT_WIDTHS.get(type)
		if width is not None:
			self.add_data(int(data), width)
		elif type == Type.FLOAT:
			# Pack directly: float has no to_bytes(), and '>f' mirrors how
			# FLOAT values are decoded from replies.
			self.data += struct.pack('>f', float(data))
		elif type == Type.STRING:
			self.data += (str(data) + '\x00').encode('utf-8')

class Response:
	"""A decoded reply frame.

	*buf* is decoded in the constructor. :meth:`decode` does not strip ``<``
	or ``>`` framing bytes, so a buffer still ending in ``>`` fails the trailer
	check (presumably the caller strips framing — TODO confirm). Malformed
	input leaves ``valid`` False instead of raising. This covers a bad escape,
	a length or checksum mismatch, an unknown command or register-type byte,
	a truncated value, invalid UTF-8 and a non-hex error code.

	Attributes after construction:
		valid: True only for a READMANY, NAME, WRITE or ERROR reply whose
			frame checked out and whose payload decoded completely.
		command: the reply's ``Command`` (``Command.ERROR`` if the frame
			could not be decoded).
		start_address: first register address for READMANY/NAME/WRITE
			replies, otherwise None.
		data: payload after the command byte (the raw *buf* if the frame
			did not decode).
		registers: decoded ``Register`` objects; emptied on payload errors.
		error: signed 16-bit code of an ERROR reply, otherwise None.
	"""
	valid: bool = False
	command: Command
	start_address: Optional[int]
	data: bytes
	index: int  # read cursor into data, advanced by read_value()
	error: Optional[int]
	registers: List[Register]

	def __init__(self, buf: bytes) -> None:
		self.valid = False
		self.command = Command.ERROR
		self.start_address = None
		self.data = buf
		self.index = 0
		self.error = None
		self.registers: List[Register] = list()
		packet = self.decode(buf)
		if not packet:
			# Bad escape, length or checksum (or empty payload): stay invalid.
			return
		try:
			# Take the command from the decoded packet, not the raw buffer.
			self.command = Command(packet[0])
			self.data = packet[1:]
			if self.command == Command.READMANY:
				self.decode_read()
			elif self.command == Command.NAME:
				self.decode_name()
			elif self.command == Command.WRITE:
				self.decode_write()
			elif self.command == Command.ERROR:
				code = int(self.data.decode('ascii'), 16)
				if code > 0x7FFF:
					# Reinterpret as a signed 16-bit value.
					code -= 0x10000
				self.error = code
				self.valid = True
		except ValueError:
			# Unknown command/register-type byte, truncated value, invalid
			# UTF-8 or non-hex error code (UnicodeDecodeError is a ValueError).
			self.valid = False
			self.registers = []

	def decode(self, arr: bytes) -> bytes | None:
		"""Unescape *arr* and verify its trailer; return the payload or None.

		``%XX`` (uppercase hex) sequences are unescaped. The byte ``|``
		marks the start of a two-byte trailer: the payload length and the
		checksum ``(255 + sum(payload)) % 256``. As a special case, a
		5-byte body with no ``|`` that starts with ``Command.ERROR`` is
		returned as is.

		Returns:
			The payload bytes (command byte first), or None if an escape is
			truncated or not uppercase hex, or the trailer does not match.
		"""
		digits = b'0123456789ABCDEF'
		result = bytearray()
		delimiter = None
		i = 0
		while i < len(arr):
			ch = arr[i]
			if ch == ord('|'):
				delimiter = len(result)
			elif ch == ord('%'):
				# An escape needs two more bytes (arr[i+1] and arr[i+2]).
				if i + 2 >= len(arr):
					return None
				high = digits.find(arr[i + 1])
				low = digits.find(arr[i + 2])
				if high < 0 or low < 0:
					return None
				result.append(high * 16 + low)
				i += 2
			else:
				result.append(ch)
			i += 1
		if (delimiter is None) and (len(result) == 5) and (result[0] == Command.ERROR.value):
			return bytes(result)
		# The trailer must be exactly two bytes after the delimiter.
		if delimiter is None or delimiter != len(result) - 2:
			return None
		size, crc = result[-2], result[-1]
		payload = bytes(result[:delimiter])
		if size != delimiter or crc != (255 + sum(payload)) % 256:
			return None
		return payload

	def _take(self, n: int) -> bytes:
		"""Return the next *n* bytes of ``data`` and advance ``index``.

		Raises:
			ValueError: if fewer than *n* bytes remain.
		"""
		end = self.index + n
		if end > len(self.data):
			raise ValueError(
				f"truncated payload: need {n} bytes at offset {self.index}, "
				f"have {len(self.data) - self.index}"
			)
		chunk = self.data[self.index:end]
		self.index = end
		return chunk

	def read_value(self, type: Type):
		"""Read one value of *type* at ``index`` and advance past it.

		BYTE, WORD and DWORD are unsigned big-endian integers of 1, 2 and 4
		bytes. INT is a signed 4-byte big-endian integer. FLOAT is IEEE-754
		single precision, big-endian. STRING is NUL-terminated UTF-8, or the
		rest of the payload if there is no NUL. For any other type, returns
		None without consuming bytes.

		Raises:
			ValueError: the payload ends before a fixed-size value is complete,
				or a string is not valid UTF-8.
		"""
		if type == Type.BYTE:
			return self._take(1)[0]
		if type == Type.WORD:
			return int.from_bytes(self._take(2), byteorder='big', signed=False)
		if type == Type.DWORD:
			return int.from_bytes(self._take(4), byteorder='big', signed=False)
		if type == Type.INT:
			return int.from_bytes(self._take(4), byteorder='big', signed=True)
		if type == Type.FLOAT:
			return struct.unpack('>f', self._take(4))[0]
		if type == Type.STRING:
			value = self.data[self.index:].split(b'\x00')[0]
			self.index += len(value) + 1  # skip the string and its NUL
			return str(value, encoding='utf-8')
		return None

	def decode_read(self):
		"""Decode a READMANY reply.

		The layout is a DWORD start address, then (type byte, value) pairs
		for consecutive register addresses until the payload ends.
		"""
		addr = self.read_value(Type.DWORD)
		self.start_address = addr
		while self.index < len(self.data):
			reg = Register(addr, Type(self.read_value(Type.BYTE)))
			reg.updated = time.time()
			reg.value = self.read_value(reg.type)
			self.registers.append(reg)
			addr += 1
		self.valid = True

	def decode_name(self):
		"""Decode a NAME reply: a DWORD address followed by a NUL-terminated name."""
		addr = self.read_value(Type.DWORD)
		self.start_address = addr
		value = self.read_value(Type.STRING)
		if (addr is not None) and (value is not None):
			reg = Register(addr, Type.EMPTY)
			reg.name = value
			self.registers.append(reg)
			self.valid = True

	def decode_write(self):
		"""Decode a WRITE acknowledgement: a DWORD address."""
		addr = self.read_value(Type.DWORD)
		self.start_address = addr
		if addr is not None:
			self.valid = True
