package tools

import (
	"errors"
	"fmt"
	"math/big"
)

func NewPoint(x, y *big.Int, curve *Curve) *Point {
	return &Point{
		curve: curve,
		x:     intCopy(x),
		y:     intCopy(y),
	}
}

type Point struct {
	x     *big.Int
	y     *big.Int
	curve *Curve
}

func (s *Point) String() string {
	return fmt.Sprintf("(%v, %v)", s.x, s.y)
}

func (s *Point) PointNull() *Point {
	return &Point{
		curve: s.curve,
	}
}

// Проверка, валидна ли точка.
// Точка считается валидной, если определены x, y и объект кривой
func (s *Point) IsValid() (err error) {
	if s.x == nil {
		err = errors.New("координата [ x ] точки не определена")
	} else if s.y == nil {
		err = errors.New("координата [ y ] точки не определена")
	} else if s.curve == nil {
		err = errors.New("кривая точки не определена")
	} else {
		err = s.curve.IsValidP()
	}
	return
}

func (s *Point) InitXY(x, y *big.Int) *Point {
	return &Point{
		x:     intCopy(x),
		y:     intCopy(y),
		curve: s.curve,
	}
}

// Копирование точки
func (s *Point) Copy() *Point {
	res := &Point{
		curve: s.curve,
	}
	if s.x != nil {
		res.x = new(big.Int).Set(s.x)
	}
	if s.y != nil {
		res.y = new(big.Int).Set(s.y)
	}
	return res
}

func (s *Point) ShowHex() string {
	return fmt.Sprintf(
		"0x%x, 0x%x",
		mustBigInt(s.x),
		mustBigInt(s.y),
	)
}

func (s *Point) Show() string {
	return fmt.Sprintf(
		"%s, %s",
		mustBigInt(s.x),
		mustBigInt(s.y),
	)
}

func (s *Point) Coords() (x, y *big.Int) {
	if s.x != nil {
		x = new(big.Int).Set(s.x)
	}
	if s.y != nil {
		y = new(big.Int).Set(s.y)
	}
	return
}

// Проверка принадлежности точки кривой.
// Точка должна удовлетворять уровнению (y * y - x * x * x - c.a * x - c.b) % c.p == 0
func (s *Point) IsInCurve() error {
	if err := s.IsValid(); err != nil {
		return err
	}
	if err := s.curve.IsValidP(); err != nil {
		return err
	}
	x, y := s.Coords()
	c := s.curve
	y2 := Exp64(y, 2)
	x3 := Exp64(x, 3)
	cax := Mul(c.a, x)
	res := Sub(Sub(Sub(y2, x3), cax), c.b)
	res = Rem(res, c.p)
	if res.Cmp(big.NewInt(0)) != 0 {
		return fmt.Errorf("точка [ %s ] не пренадлежит кривой", s.Show())
	}
	return nil
}

// Вычисление наклона прямой, проходящей через 2 точки эллиптической кривой
func (s *Point) GetIncline(pt *Point) (m *big.Int, err error) {
	if err = s.IsValid(); err != nil {
		return
	}
	if err = pt.IsValid(); err != nil {
		return
	}
	m, cur := intCopy(intZero), s.curve
	x1, y1 := s.Coords()
	x2, y2 := pt.Coords()
	var iMod *big.Int
	if Cmp(x1, x2) { // !!! (points compare)
		// m = (3 * x1 * x1 + cur.a) * cur.inverseMod(2 * y1, cur.p)
		if iMod, err = cur.InverseMod(Mul(big.NewInt(2), y1), cur.p); err != nil {
			return
		}
		m = Mul(
			Add(
				Mul(big.NewInt(3), Exp64(x1, 2)),
				cur.a,
			),
			iMod,
		)
	} else {
		// m = (y1 - y2) * cur.inverseMod(x1 - x2, cur.p)
		if iMod, err = cur.InverseMod(Sub(x1, x2), cur.p); err != nil {
			return
		}
		//log.Println("imod", iMod)
		m = Mul(
			Sub(y1, y2),
			iMod,
		)
	}
	return
}

func (s *Point) Add(pt *Point) (rpt *Point, err error) {
	if err = s.IsValid(); err != nil {
		err, rpt = nil, pt.Copy()
	} else if err = pt.IsValid(); err != nil {
		err, rpt = nil, s.Copy()
	} else if s.x.Cmp(pt.x) == 0 && s.y.Cmp(pt.y) != 0 {
		rpt = s.PointNull()
	} else {
		var m *big.Int
		if m, err = s.GetIncline(pt); err != nil {
			return
		}
		// rx = m * m - s.x - pt.x
		//rxry 567678238 13525501245905
		//log.Println("MMMMMMM", m, s.x, s.y, pt.x, pt.y)
		rx := Sub(Sub(Mul(m, m), s.x), pt.x)
		// ry = s.y + m * (rx - s.x)
		ry := Add(s.y, Mul(m, Sub(rx, s.x)))
		//log.Println("rxry", rx, ry)
		rpt = s.InitXY(
			//x = rx % s.curve.p,
			Mod(rx, s.curve.p),
			//y = -ry % s.curve.p,
			Mod(Neg(ry), s.curve.p),
		)
	}
	return
}

// Унарный минус
func (s *Point) Neg() (pt *Point, err error) {
	if err = s.IsValid(); err != nil {
		return
	}
	pt = s.InitXY(
		intCopy(s.x),
		Mod(Neg(s.y), s.curve.p),
	)
	return
}

func (s *Point) Compare(pt *Point) (check bool, err error) {
	if err = s.IsValid(); err != nil {
		return
	}
	if err = pt.IsValid(); err != nil {
		return
	}
	check = s.x.Cmp(pt.x) == 0
	return
}

func (s *Point) Mul(k *big.Int) (pt *Point, err error) {

	if err = s.IsValid(); err == nil {
		if err = s.curve.IsValidN(); err != nil {
			return
		}
	}
	if k.Cmp(big.NewInt(0)) < 0 {
		// k * point = -k * (-point)
		if pt, err = s.Neg(); err == nil {
			pt, err = pt.Mul(Neg(k))
		}
	} else {
		pt = s.PointNull()
		addend := s.Copy()
		k = intCopy(k)
		for k.Cmp(intZero) != 0 {
			if And(k, big.NewInt(1)).Cmp(intZero) > 0 {
				if pt, err = pt.Add(addend); err != nil {
					return
				}
			}
			if addend, err = addend.Add(addend); err != nil {
				return
			}
			// k >>= 1
			k.Rsh(k, 1)
		}
	}
	return
}

/*
    # Умножение
   def __mul__(self, k):
       if self.isNone() or k % self.curve.n == 0:
           return self.pointNull()

       if k < 0:
           # k * point = -k * (-point)
           return -self * -k

       res = self.pointNull()
       addend = self.copy()

       while k:
           if k & 1:
               # Add.
               res = res + addend
           # Double.
           addend = addend + addend
           k >>= 1
       return res

   # Вычисление хэша точки (числовая строка суммы координат)
   def md5(self):
       src = str(self.x + self.y)
       return hashlib.md5(src.encode('utf-8')).digest()

   # Вычисление хэша по оси x
   def md5X(self):
       return hashlib.md5(str(self.x).encode('utf-8')).digest()

   # Вычисление хэша по оси y
   def md5Y(self):
       return hashlib.md5(str(self.y).encode('utf-8')).digest()

   # Проверка совпадения координат 2х точек по осям x и y
   def isEqual(self, point):
       x1, y1 = self.coords()
       x2, y2 = point.coords()
       return x1 == x2 and y1 == y2
*/