package main
import (
"bufio"
"context"
"fmt"
"log"
"net"
"os"
"os/signal"
"sync"
"syscall"
"time"
)
// Client represents a single TCP client connection handled by a TcpServer.
type Client struct {
	// Identity and owning server.
	id     string     // remote address of the connection, used as the registry key
	server *TcpServer // server that registered this client

	// Transport and outbound queue.
	conn   net.Conn
	sendCh chan []byte // messages waiting to be written to conn

	// Lifecycle.
	ctx    context.Context    // done once the client has been shut down
	cancel context.CancelFunc // cancels ctx
	once   sync.Once          // guards cleanup so it runs at most once
}
// NewClient wraps conn in a Client owned by server. The client's id is the
// connection's remote address, its outbound queue holds up to 100 messages,
// and it gets its own cancellable context rooted at context.Background.
// No goroutines are started here; call Start for that.
func NewClient(conn net.Conn, server *TcpServer) *Client {
	c := &Client{
		id:     conn.RemoteAddr().String(),
		conn:   conn,
		server: server,
		sendCh: make(chan []byte, 100),
	}
	c.ctx, c.cancel = context.WithCancel(context.Background())
	return c
}
// Start launches the client's reader and writer goroutines and returns
// immediately without waiting for either of them.
func (c *Client) Start() {
	for _, loop := range []func(){c.readLoop, c.writeLoop} {
		go loop()
	}
}
// readLoop reads newline-terminated lines from the connection until the
// client's context is cancelled, a read fails, or the peer sends "quit".
// Every other line is echoed back prefixed with "echo: ". The client is
// cleaned up when the loop exits.
//
// Both "\n" and "\r\n" line endings are accepted, so clients such as telnet
// that send CRLF can still issue "quit".
//
// NOTE(review): ReadString places no upper bound on line length, so a peer
// that never sends '\n' can make the buffer grow without limit.
func (c *Client) readLoop() {
	defer c.cleanup()

	reader := bufio.NewReader(c.conn)
	for {
		if c.ctx.Err() != nil {
			return
		}

		// Refresh the idle timeout before every read.
		if err := c.conn.SetReadDeadline(time.Now().Add(5 * time.Minute)); err != nil {
			log.Printf("client %s set read deadline error: %v", c.id, err)
			return
		}

		line, err := reader.ReadString('\n')
		if err != nil {
			log.Printf("client %s read error: %v", c.id, err)
			return
		}

		// ReadString returned a nil error, so line ends with '\n'; strip it,
		// then strip the '\r' of a CRLF terminator if one is present.
		line = line[:len(line)-1]
		if n := len(line); n > 0 && line[n-1] == '\r' {
			line = line[:n-1]
		}
		log.Printf("recv from %s: %s", c.id, line)

		// Example command: "quit" closes the connection.
		if line == "quit" {
			log.Printf("client %s requested quit", c.id)
			return
		}

		// Plain echo.
		c.Send([]byte("echo: " + line + "\n"))
	}
}
// writeLoop drains the outbound queue and writes each message to the
// connection with a 10-second write deadline. It exits when the client's
// context is cancelled, the queue is closed, or a write fails, and cleans
// the client up on the way out.
func (c *Client) writeLoop() {
	defer c.cleanup()

	for {
		var payload []byte
		select {
		case <-c.ctx.Done():
			return
		case msg, open := <-c.sendCh:
			if !open {
				return
			}
			payload = msg
		}

		c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
		if _, err := c.conn.Write(payload); err != nil {
			log.Printf("client %s write error: %v", c.id, err)
			return
		}
	}
}
// Send queues data for delivery to the client without blocking. If the
// outbound queue is full the message is dropped and the drop is logged.
func (c *Client) Send(data []byte) {
	select {
	case c.sendCh <- data:
		return
	default:
	}
	log.Printf("client %s send channel full, drop message", c.id)
}
// cleanup shuts the client down: it cancels the client's context, closes the
// connection, and asks the server to drop the client from its registry.
// It may be called from several goroutines; only the first call does the work.
func (c *Client) cleanup() {
	c.once.Do(func() {
		log.Printf("cleaning up client %s", c.id)

		// Cancelling the context is what stops writeLoop.
		c.cancel()

		// Closing the connection unblocks a reader parked in a blocking read.
		if err := c.conn.Close(); err != nil {
			log.Printf("client %s close error: %v", c.id, err)
		}

		// sendCh is deliberately left open: Send can still be running in
		// another goroutine, and a send on a closed channel panics. The
		// channel is simply garbage-collected together with the client.

		c.server.removeClient(c.id)
	})
}
// ------------------------------------------------------
// TcpServer is a TCP server that tracks its connected clients.
type TcpServer struct {
	listener net.Listener // set by Start

	mu      sync.Mutex         // guards clients
	clients map[string]*Client // connected clients keyed by Client.id

	wg     sync.WaitGroup     // counts the accept loop and per-client goroutines
	ctx    context.Context    // done once Stop has been called
	cancel context.CancelFunc // cancels ctx
}
// NewTcpServer returns a server with an empty client registry and its own
// cancellable context. It does not listen until Start is called.
func NewTcpServer() *TcpServer {
	srv := &TcpServer{clients: make(map[string]*Client)}
	srv.ctx, srv.cancel = context.WithCancel(context.Background())
	return srv
}
// Start begins listening for TCP connections on address and launches the
// accept loop in a background goroutine.
//
// It returns a wrapped error if the listener cannot be created; in that case
// the server's listener field is left untouched and no goroutine is started.
func (s *TcpServer) Start(address string) error {
	ln, err := net.Listen("tcp", address)
	if err != nil {
		return fmt.Errorf("listen on %q: %w", address, err)
	}
	s.listener = ln
	log.Printf("server started on %s", address)

	// Register the accept loop before starting it so a later wg.Wait sees it.
	s.wg.Add(1)
	go s.acceptLoop()
	return nil
}
// acceptLoop accepts connections until the server's context is cancelled.
// Each accepted connection is wrapped in a Client, registered in the client
// map, and started; a goroutine tracked by s.wg then waits for that client's
// context to be cancelled.
//
// After an accept error that is not caused by shutdown, the loop pauses
// briefly before retrying so a persistent failure does not spin the CPU.
func (s *TcpServer) acceptLoop() {
	defer s.wg.Done()

	const retryDelay = 50 * time.Millisecond

	for {
		conn, err := s.listener.Accept()
		if err != nil {
			if s.ctx.Err() != nil {
				// The server was stopped; leave the loop.
				log.Println("listener closed, stopping accept loop")
				return
			}
			log.Printf("accept error: %v", err)

			// Back off, but wake up immediately if the server is stopped.
			select {
			case <-s.ctx.Done():
				log.Println("listener closed, stopping accept loop")
				return
			case <-time.After(retryDelay):
			}
			continue
		}

		client := NewClient(conn, s)
		s.mu.Lock()
		s.clients[client.id] = client
		s.mu.Unlock()

		// A connection can be accepted just as the server is being stopped.
		// If shutdown has already begun, nobody else may tear this client
		// down, so do it here instead of starting it.
		if s.ctx.Err() != nil {
			client.cleanup()
			continue
		}

		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			client.Start()
			<-client.ctx.Done()
			log.Printf("client %s goroutine exit", client.id)
		}()
	}
}
// removeClient deletes the client registered under id from the client map
// and logs the removal. The map is accessed under s.mu.
func (s *TcpServer) removeClient(id string) {
	s.mu.Lock()
	delete(s.clients, id)
	log.Printf("client %s removed from server", id)
	s.mu.Unlock()
}
// Stop shuts the server down: it cancels the server context, closes the
// listener, cleans up every registered client, and then blocks until all
// goroutines tracked by s.wg have finished.
//
// Stop is safe to call on a server whose Start was never called or failed.
func (s *TcpServer) Stop() {
	log.Println("stopping server...")
	s.cancel()

	if s.listener != nil {
		if err := s.listener.Close(); err != nil {
			log.Printf("listener close error: %v", err)
		}
	}

	// Snapshot the clients and release the lock before cleaning them up:
	// cleanup calls removeClient, which takes s.mu itself, and sync.Mutex is
	// not reentrant, so holding the lock across cleanup would deadlock.
	s.mu.Lock()
	clients := make([]*Client, 0, len(s.clients))
	for _, client := range s.clients {
		clients = append(clients, client)
	}
	s.mu.Unlock()

	for _, client := range clients {
		client.cleanup()
	}

	s.wg.Wait()
	log.Println("server stopped gracefully")
}
// ------------------------------------------------------
// main starts the TCP server on :9999, blocks until SIGINT or SIGTERM
// arrives, and then stops the server.
func main() {
	srv := NewTcpServer()
	if err := srv.Start(":9999"); err != nil {
		log.Fatalf("failed to start server: %v", err)
	}

	// Wait for a termination signal. The channel is buffered so a signal
	// delivered before the receive below is not lost.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

	<-sigCh
	log.Println("received interrupt signal, shutting down...")
	srv.Stop()
}
// Original article: https://www.blear.cn/article/golang-tcp-server
// When reposting, please credit the source with a link.