package main
import (
"bufio"
"context"
"log"
"net"
"os"
"os/signal"
"sync"
"syscall"
"time"
)
/* ======================= Client ======================= */
// Client is one accepted TCP connection. readLoop handles incoming lines and
// writeLoop writes the messages queued on sendCh.
type Client struct {
	// conn is the accepted connection; id is its remote address string.
	conn net.Conn
	id   string

	// server owns this client and is notified through removeClient on cleanup.
	server *TcpServer

	// sendCh is the outbound queue that writeLoop drains onto conn.
	sendCh chan []byte

	// ctx/cancel control both loops; cleanup and the server's Stop call cancel.
	ctx    context.Context
	cancel context.CancelFunc

	// once makes cleanup run a single time, since readLoop and writeLoop both defer it.
	once sync.Once
}
// NewClient wraps an accepted connection in a Client owned by server.
//
// The client's context is derived from the server's context. Cancelling the
// server (as Stop does) therefore also cancels every client, including one
// that acceptLoop registers after Stop has already walked the client map.
// The client id is the remote address string. The outbound queue buffers up
// to 100 messages.
func NewClient(conn net.Conn, server *TcpServer) *Client {
	parent := context.Background()
	if server != nil && server.ctx != nil {
		parent = server.ctx
	}
	ctx, cancel := context.WithCancel(parent)
	return &Client{
		id:     conn.RemoteAddr().String(),
		conn:   conn,
		ctx:    ctx,
		cancel: cancel,
		sendCh: make(chan []byte, 100),
		server: server,
	}
}

// Start launches the client's read and write loops in their own goroutines.
//
// Both goroutines are registered on the server's WaitGroup, so Stop waits
// for every client to finish its cleanup, not just for the accept loop.
// When Start is called from acceptLoop (as in this file), the caller holds
// its own wg slot. The counter is therefore non-zero at the Add, which
// sync.WaitGroup requires when Add may run concurrently with Wait.
func (c *Client) Start() {
	var wg *sync.WaitGroup
	if c.server != nil {
		wg = &c.server.wg
		wg.Add(2)
	}
	run := func(loop func()) {
		if wg != nil {
			defer wg.Done()
		}
		loop()
	}
	go run(c.readLoop)
	go run(c.writeLoop)
}
// readLoop reads newline-delimited messages from the connection and echoes
// each one back through Send, prefixed with "echo: ".
//
// It returns, and its deferred cleanup tears the client down, in these cases:
//   - the client context is cancelled;
//   - the peer disconnects;
//   - a read fails or hits the 5-minute idle deadline;
//   - a line exceeds the scanner's size limit;
//   - the client sends "quit".
func (c *Client) readLoop() {
	defer c.cleanup()
	// bufio.Scanner caps a single line at bufio.MaxScanTokenSize, so a peer
	// that never sends '\n' cannot grow memory without bound; its default
	// ScanLines splitter also strips "\r\n", so telnet-style "quit" matches.
	scanner := bufio.NewScanner(c.conn)
	for c.ctx.Err() == nil {
		// Refresh the idle timeout before each blocking read.
		c.conn.SetReadDeadline(time.Now().Add(5 * time.Minute))
		if !scanner.Scan() {
			if err := scanner.Err(); err != nil {
				log.Printf("client %s read error: %v", c.id, err)
			} else {
				// A nil Err after a failed Scan means a clean EOF.
				log.Printf("client %s disconnected", c.id)
			}
			return
		}
		line := scanner.Text()
		log.Printf("recv from %s: %s", c.id, line)
		if line == "quit" {
			return
		}
		c.Send([]byte("echo: " + line + "\n"))
	}
}
// writeLoop writes queued messages from sendCh to the connection, with a
// 10-second write deadline per message. It exits, and its deferred cleanup
// tears the client down, in three cases: the client context is cancelled,
// sendCh is closed, or a write fails.
func (c *Client) writeLoop() {
	defer c.cleanup()
	done := c.ctx.Done()
	for {
		var msg []byte
		var open bool
		select {
		case <-done:
			return
		case msg, open = <-c.sendCh:
		}
		if !open {
			return
		}
		deadline := time.Now().Add(10 * time.Second)
		c.conn.SetWriteDeadline(deadline)
		_, err := c.conn.Write(msg)
		if err != nil {
			log.Printf("client %s write error: %v", c.id, err)
			return
		}
	}
}
// Send queues data for writeLoop without ever blocking the caller.
//
// Once the client has been shut down (its context cancelled), data is
// silently discarded. writeLoop no longer drains the queue at that point, so
// queueing would only fill the buffer and log a misleading "buffer full".
// While the client is live, a message is dropped and logged when the send
// buffer is full.
func (c *Client) Send(data []byte) {
	if c.ctx.Err() != nil {
		return
	}
	select {
	case c.sendCh <- data:
	default:
		log.Printf("client %s send buffer full, drop message", c.id)
	}
}
// cleanup tears the client down exactly once, however many loops defer it.
// It does three things, in order:
//   - cancels the client context, which stops the other loop;
//   - closes the connection, which unblocks a pending read or write;
//   - unregisters the client from the server.
//
// sendCh is deliberately NOT closed. Send may still be running on another
// goroutine: readLoop calls it, and as an exported method it may be called
// from elsewhere. A send on a closed channel panics even inside a select
// with a default case. writeLoop already exits on ctx.Done, so closing the
// channel buys nothing; it is garbage-collected with the Client.
func (c *Client) cleanup() {
	c.once.Do(func() {
		log.Printf("cleanup client %s", c.id)
		c.cancel()
		c.conn.Close()
		c.server.removeClient(c.id)
	})
}
/* ======================= Server ======================= */
// TcpServer accepts TCP connections and keeps a registry of active clients
// keyed by client id.
type TcpServer struct {
	// listener is set by Start and closed by Stop.
	listener net.Listener

	// mu guards clients.
	mu      sync.Mutex
	clients map[string]*Client

	// ctx is cancelled by Stop so acceptLoop can tell shutdown from a real
	// Accept error.
	ctx    context.Context
	cancel context.CancelFunc

	// wg counts goroutines registered on the server (acceptLoop at minimum);
	// Stop waits on it.
	wg sync.WaitGroup
}
// NewTcpServer returns a server with an empty client registry and its own
// cancellable context. Call Start to begin listening.
func NewTcpServer() *TcpServer {
	s := &TcpServer{clients: map[string]*Client{}}
	s.ctx, s.cancel = context.WithCancel(context.Background())
	return s
}
// Start listens for TCP on addr (for example ":9999") and runs the accept loop
// in a background goroutine tracked by s.wg. If listening fails, the error is
// returned and nothing is started.
func (s *TcpServer) Start(addr string) error {
	ln, err := net.Listen("tcp", addr)
	if err == nil {
		s.listener = ln
		log.Printf("TCP server started on %s", addr)
		s.wg.Add(1)
		go s.acceptLoop()
	}
	return err
}
// acceptLoop accepts connections until the server context is cancelled. For
// each connection it registers a Client and starts it. It marks s.wg done
// when it returns.
//
// Accept errors that are not caused by shutdown are retried with exponential
// backoff: 5ms, doubling, capped at 1s. Without it, a persistent error (for
// example running out of file descriptors) would spin the CPU and flood the
// log.
func (s *TcpServer) acceptLoop() {
	defer s.wg.Done()
	const (
		minBackoff = 5 * time.Millisecond
		maxBackoff = time.Second
	)
	var backoff time.Duration
	for {
		conn, err := s.listener.Accept()
		if err != nil {
			// Stop cancels s.ctx before closing the listener, so a cancelled
			// context means this error is the expected shutdown error.
			if s.ctx.Err() != nil {
				log.Println("accept loop exit")
				return
			}
			if backoff == 0 {
				backoff = minBackoff
			} else {
				backoff *= 2
			}
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
			log.Printf("accept error: %v; retrying in %v", err, backoff)
			select {
			case <-s.ctx.Done():
				log.Println("accept loop exit")
				return
			case <-time.After(backoff):
			}
			continue
		}
		backoff = 0

		s.mu.Lock()
		// Re-check shutdown while holding s.mu. Stop cancels s.ctx before it
		// takes s.mu to cancel registered clients. A connection accepted
		// during shutdown is therefore either seen by Stop's loop or rejected
		// here, and never left running unnoticed.
		if s.ctx.Err() != nil {
			s.mu.Unlock()
			conn.Close()
			log.Println("accept loop exit")
			return
		}
		client := NewClient(conn, s)
		s.clients[client.id] = client
		s.mu.Unlock()
		client.Start()
	}
}
// removeClient deletes the client with the given id from the registry and
// logs the removal while holding s.mu. An unknown id only produces the log
// line.
func (s *TcpServer) removeClient(id string) {
	s.mu.Lock()
	delete(s.clients, id)
	log.Printf("client %s removed", id)
	s.mu.Unlock()
}
// Stop shuts the server down in four steps:
//   - cancels the server context;
//   - closes the listener so acceptLoop returns;
//   - cancels every registered client;
//   - waits for the goroutines tracked by s.wg.
//
// It is safe to call when Start never succeeded: there is no listener to close.
func (s *TcpServer) Stop() {
	log.Println("server stopping...")
	// Cancel before closing the listener: acceptLoop checks s.ctx to tell the
	// resulting Accept error apart from a real failure.
	s.cancel()
	if s.listener != nil {
		if err := s.listener.Close(); err != nil {
			log.Printf("listener close error: %v", err)
		}
	}
	s.mu.Lock()
	for _, c := range s.clients {
		c.cancel()
	}
	s.mu.Unlock()
	s.wg.Wait()
	log.Println("server stopped gracefully")
}
/* ======================= main ======================= */
// main starts the echo server on :9999 and runs until SIGINT or SIGTERM,
// then shuts the server down gracefully.
func main() {
	server := NewTcpServer()
	if err := server.Start(":9999"); err != nil {
		log.Fatalf("start server failed: %v", err)
	}
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
	<-sig
	// Restore default signal handling. A second Ctrl-C / SIGTERM then kills
	// the process immediately if the graceful shutdown below gets stuck,
	// instead of being silently swallowed by the still-registered channel.
	signal.Stop(sig)
	log.Println("signal received, shutting down")
	server.Stop()
}
// 本文地址 (source article): https://www.blear.cn/article/golang-tcp-server
// 转载时请以链接形式注明出处 (when reposting, credit the source with a link).
// 评论 (comments)