socket.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. package rest_websocket
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "log"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. "git.ali33.ru/fcg-xvii/rest"
  11. "github.com/gorilla/websocket"
  12. )
  13. func NewSocket(conn *websocket.Conn, pingEnable bool) *Socket {
  14. ctx, cancel := context.WithCancel(context.Background())
  15. ws := &Socket{
  16. conn: conn,
  17. wait: new(sync.Map),
  18. ctx: ctx,
  19. cancel: cancel,
  20. writeLocker: &sync.Mutex{},
  21. chIn: make(chan *rest.RequestStream, 10),
  22. pingEnable: pingEnable,
  23. }
  24. ws.lastWrite.Store(time.Now().Unix())
  25. go ws.read()
  26. return ws
  27. }
  28. type Socket struct {
  29. ctx context.Context
  30. cancel context.CancelFunc
  31. conn *websocket.Conn
  32. wait *sync.Map
  33. chIn chan *rest.RequestStream
  34. writeLocker *sync.Mutex
  35. idCounter atomic.Int64
  36. lastWrite atomic.Int64
  37. pingEnable bool
  38. }
  39. func (s *Socket) Context() context.Context {
  40. return s.ctx
  41. }
  42. // MessagesIn возвращает канал, в который будут переданы все входящие сообщения (rest.RequestTypeMessage и rest.RequestTypeEvent)
  43. func (s *Socket) MessagesIn() <-chan *rest.RequestStream {
  44. return s.chIn
  45. }
  46. // getID увеличивает счётчик сообщений на единицу и возвраащает результат. используется для маркирования идентификаторами исходящих сообщений
  47. func (s *Socket) getID() int64 {
  48. return s.idCounter.Add(1)
  49. }
  50. // read реализует чтение входящих сообщений
  51. func (s *Socket) read() {
  52. //defer log.Println("work close...")
  53. // контекст
  54. s.ctx, s.cancel = context.WithCancel(context.Background())
  55. defer func() {
  56. s.cancel()
  57. s.conn.Close()
  58. }()
  59. // создаем канал для обработки входящих сообщений
  60. chIn := s.exec()
  61. for {
  62. // Read message from server
  63. mType, r, err := s.conn.NextReader()
  64. if err != nil {
  65. s.cancel()
  66. log.Println(err)
  67. return
  68. }
  69. switch mType {
  70. case websocket.TextMessage, websocket.BinaryMessage:
  71. // Обработка текстового или бинарного сообщения
  72. req, err := rest.ReadRequestStream(r)
  73. if err != nil {
  74. log.Println("data error: ", err)
  75. return
  76. }
  77. log.Println("RESPONSE", req)
  78. chIn <- req
  79. }
  80. }
  81. }
  82. // exec реализует обработку сообщений.
  83. func (s *Socket) exec() chan<- *rest.RequestStream {
  84. ch := make(chan *rest.RequestStream)
  85. go func() {
  86. //defer log.Println("exec close...")
  87. for {
  88. select {
  89. // закрытие контекста
  90. case <-s.ctx.Done():
  91. return
  92. // новое сообщение
  93. case req, ok := <-ch:
  94. if !ok {
  95. return
  96. }
  97. //log.Println("OOOOOOOOOOOOOOOOOOOOOOOOOOO")
  98. switch req.Type {
  99. case rest.RequestTypeIn:
  100. s.chIn <- req
  101. case rest.RequestTypeEvent:
  102. s.chIn <- req
  103. case rest.RequestTypeOut:
  104. log.Println("answer in", req.ID)
  105. ir, check := s.wait.Load(req.ID)
  106. if check {
  107. rreq := ir.(*waitRequest)
  108. s.wait.Delete(rreq.id)
  109. rreq.answerIn <- req
  110. }
  111. }
  112. // чистка просроченных сообщений и отправка пинга (при необходимости)
  113. case <-time.After(time.Second * 10):
  114. // чистим сообщения без ответа по дедлайну
  115. now := time.Now()
  116. s.wait.Range(func(key, val any) bool {
  117. if val.(*waitRequest).timeout.Before(now) {
  118. log.Println("CLEAN...", key)
  119. s.wait.Delete(key)
  120. }
  121. return true
  122. })
  123. // отправляем пинг для проверки, живое соединение или нет
  124. if s.pingEnable && now.Unix()-s.lastWrite.Load() > 10 {
  125. log.Println("PING")
  126. //s.writeLocker.Lock()
  127. err := s.conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(time.Second))
  128. //s.writeLocker.Unlock()
  129. if err != nil {
  130. log.Println("ping send error")
  131. s.conn.Close()
  132. return
  133. }
  134. s.lastWrite.Store(time.Now().Unix())
  135. }
  136. }
  137. }
  138. }()
  139. return ch
  140. }
  141. func (s *Socket) nextWriter(messageType int) (io.WriteCloser, error) {
  142. s.writeLocker.Lock()
  143. res, err := s.conn.NextWriter(messageType)
  144. s.writeLocker.Unlock()
  145. s.lastWrite.Store(time.Now().Unix())
  146. return res, err
  147. }
  148. func (s *Socket) SendMessage(req *rest.RequestStream) (ch <-chan *rest.RequestStream, err error) {
  149. switch req.Type {
  150. case rest.RequestTypeIn:
  151. req.ID = s.getID()
  152. clReq := newWaitRequest(req.ID, req.Timeout)
  153. ch = clReq.answer
  154. s.wait.Store(req.ID, clReq)
  155. case rest.RequestTypeEvent, rest.RequestTypeOut:
  156. default:
  157. return nil, errors.New("unexpected request type")
  158. }
  159. var writer io.WriteCloser
  160. if writer, err = s.nextWriter(websocket.BinaryMessage); err != nil {
  161. return
  162. }
  163. err = req.Write(writer)
  164. writer.Close()
  165. return
  166. }
  167. func (s *Socket) Close() {
  168. s.cancel()
  169. s.conn.Close()
  170. }