package rest_websocket import ( "context" "errors" "io" "log" "sync" "sync/atomic" "time" "git.ali33.ru/fcg-xvii/rest" "github.com/gorilla/websocket" ) func NewSocket(conn *websocket.Conn, pingEnable bool) *Socket { ctx, cancel := context.WithCancel(context.Background()) ws := &Socket{ conn: conn, wait: new(sync.Map), ctx: ctx, cancel: cancel, writeLocker: &sync.Mutex{}, chIn: make(chan *rest.RequestStream, 10), pingEnable: pingEnable, } ws.lastWrite.Store(time.Now().Unix()) ws.read() return ws } type Socket struct { ctx context.Context cancel context.CancelFunc conn *websocket.Conn wait *sync.Map chIn chan *rest.RequestStream writeLocker *sync.Mutex idCounter atomic.Int64 lastWrite atomic.Int64 pingEnable bool } func (s *Socket) Context() context.Context { return s.ctx } // MessagesIn возвращает канал, в который будут переданы все входящие сообщения (rest.RequestTypeMessage и rest.RequestTypeEvent) func (s *Socket) MessagesIn() <-chan *rest.RequestStream { return s.chIn } // getID увеличивает счётчик сообщений на единицу и возвраащает результат. используется для маркирования идентификаторами исходящих сообщений func (s *Socket) getID() int64 { return s.idCounter.Add(1) } // read реализует чтение входящих сообщений func (s *Socket) read() { //defer log.Println("work close...") // контекст s.ctx, s.cancel = context.WithCancel(context.Background()) // создаем канал для обработки входящих сообщений chIn := s.exec() go func() { defer func() { s.cancel() s.conn.Close() }() for { // Read message from server mType, r, err := s.conn.NextReader() if err != nil { s.cancel() log.Println(err) return } switch mType { case websocket.TextMessage, websocket.BinaryMessage: // Обработка текстового или бинарного сообщения req, err := rest.ReadRequestStream(r) if err != nil { log.Println("data error: ", err) return } //log.Println("REQUEST", req.Request.Command, req.Request.Data.JSONPrettyString()) chIn <- req } } }() } // exec реализует обработку сообщений. func (s *Socket) exec() chan<- *rest.RequestStream { ch := make(chan *rest.RequestStream) go func() { //defer log.Println("exec close...") for { select { // закрытие контекста case <-s.ctx.Done(): return // новое сообщение case req, ok := <-ch: if !ok { return } //log.Println("OOOOOOOOOOOOOOOOOOOOOOOOOOO") switch req.Type { case rest.RequestTypeIn: s.chIn <- req case rest.RequestTypeEvent: s.chIn <- req case rest.RequestTypeOut: //log.Println("answer in", req.ID) ir, check := s.wait.Load(req.ID) if check { rreq := ir.(*waitRequest) s.wait.Delete(rreq.id) rreq.answerIn <- req } } // чистка просроченных сообщений и отправка пинга (при необходимости) case <-time.After(time.Second * 10): // чистим сообщения без ответа по дедлайну now := time.Now() s.wait.Range(func(key, val any) bool { if val.(*waitRequest).timeout.Before(now) { log.Println("CLEAN...", key) s.wait.Delete(key) } return true }) // отправляем пинг для проверки, живое соединение или нет if s.pingEnable && now.Unix()-s.lastWrite.Load() > 10 { //log.Println("PING") //s.writeLocker.Lock() err := s.conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(time.Second)) //s.writeLocker.Unlock() if err != nil { log.Println("ping send error") s.conn.Close() return } s.lastWrite.Store(time.Now().Unix()) } } } }() return ch } func (s *Socket) nextWriter(messageType int) (io.WriteCloser, error) { s.writeLocker.Lock() res, err := s.conn.NextWriter(messageType) s.writeLocker.Unlock() s.lastWrite.Store(time.Now().Unix()) return res, err } func (s *Socket) SendMessage(req *rest.RequestStream) (ch <-chan *rest.RequestStream, err error) { switch req.Type { case rest.RequestTypeIn: req.ID = s.getID() clReq := newWaitRequest(req.ID, req.Timeout) ch = clReq.answer s.wait.Store(req.ID, clReq) case rest.RequestTypeEvent, rest.RequestTypeOut: default: return nil, errors.New("unexpected request type") } var writer io.WriteCloser if writer, err = s.nextWriter(websocket.BinaryMessage); err != nil { return } err = req.Write(writer) writer.Close() return } func (s *Socket) Close() { s.cancel() s.conn.Close() }