用 Go 语言实现 TCP 服务器


package main

import (
    "bufio"
    "context"
    "errors"
    "fmt"
    "log"
    "net"
    "os"
    "os/signal"
    "strings"
    "sync"
    "syscall"
    "time"
)

// Client represents a single accepted TCP client connection together with the
// state shared by its reader and writer goroutines.
type Client struct {
    // server is the owning server; cleanup deregisters the client from it.
    server *TcpServer

    // conn is the underlying connection; id is its remote address string.
    conn net.Conn
    id   string

    // ctx is cancelled by cleanup to tell readLoop and writeLoop to stop.
    ctx    context.Context
    cancel context.CancelFunc

    // sendCh queues outgoing messages for writeLoop (NewClient gives it a
    // buffer of 100).
    sendCh chan []byte

    // once makes cleanup run at most once even though both loops defer it.
    once sync.Once
}

// NewClient wraps an accepted connection in a Client owned by server. The
// client is identified by the connection's remote address, gets its own
// cancellable context, and has an outgoing queue buffered to 100 messages.
// Nothing runs until Start is called.
func NewClient(conn net.Conn, server *TcpServer) *Client {
    c := &Client{
        id:     conn.RemoteAddr().String(),
        conn:   conn,
        sendCh: make(chan []byte, 100),
        server: server,
    }
    c.ctx, c.cancel = context.WithCancel(context.Background())
    return c
}

// Start launches the client's reader goroutine and then its writer goroutine,
// returning immediately. Each loop runs cleanup when it exits.
func (c *Client) Start() {
    for _, loop := range []func(){c.readLoop, c.writeLoop} {
        go loop()
    }
}

// readLoop reads newline-terminated lines from the connection and echoes each
// one back through Send. It stops when the client context is cancelled, when a
// read fails (including hitting the 5-minute idle deadline), or when the peer
// sends "quit". Cleanup runs on exit.
func (c *Client) readLoop() {
    defer c.cleanup()

    reader := bufio.NewReader(c.conn)
    for {
        select {
        case <-c.ctx.Done():
            return
        default:
        }

        // Refresh the idle deadline before every read so a silent peer is
        // dropped after 5 minutes without a complete line.
        c.conn.SetReadDeadline(time.Now().Add(5 * time.Minute))
        line, err := reader.ReadString('\n')
        if err != nil {
            log.Printf("client %s read error: %v", c.id, err)
            return
        }
        // Strip the line terminator. Many clients (telnet, Windows nc) send
        // "\r\n"; removing only '\n' would leave a trailing '\r', so "quit"
        // would never match and every echo would carry a stray CR.
        line = strings.TrimRight(line, "\r\n")
        log.Printf("recv from %s: %s", c.id, line)

        // Example protocol command: "quit" closes the connection.
        if line == "quit" {
            log.Printf("client %s requested quit", c.id)
            return
        }

        // Simple echo.
        c.Send([]byte(fmt.Sprintf("echo: %s\n", line)))
    }
}

// writeLoop drains sendCh onto the connection one message at a time. Each write
// gets a 10-second deadline. It returns when the client context is cancelled,
// when sendCh is closed, or when a write fails, and runs cleanup on exit.
func (c *Client) writeLoop() {
    defer c.cleanup()

    for {
        var (
            data []byte
            ok   bool
        )
        select {
        case <-c.ctx.Done():
            return
        case data, ok = <-c.sendCh:
        }
        if !ok {
            return
        }

        c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
        if _, err := c.conn.Write(data); err != nil {
            log.Printf("client %s write error: %v", c.id, err)
            return
        }
    }
}

// Send queues data for delivery to the client without blocking. If the queue
// is full, the message is dropped and a log line is written. Once the client
// has been shut down (its context cancelled), writeLoop no longer drains the
// queue, so the message is dropped silently instead of filling the buffer and
// producing misleading "channel full" logs.
func (c *Client) Send(data []byte) {
    if c.ctx.Err() != nil {
        return
    }
    select {
    case c.sendCh <- data:
    default:
        log.Printf("client %s send channel full, drop message", c.id)
    }
}

// cleanup tears the client down exactly once. It:
//   - cancels the client context, which stops writeLoop and the per-client
//     wait in acceptLoop;
//   - closes the connection, which unblocks any pending read in readLoop;
//   - removes the client from the server.
//
// Repeated calls are no-ops. Callers must not hold c.server.mu, because
// removeClient acquires it.
//
// sendCh is deliberately left open. readLoop (via Send) may still be sending
// to it concurrently, and a send on a closed channel panics even inside a
// select with a default case. writeLoop already exits on ctx.Done(), and the
// channel is garbage-collected along with the Client.
func (c *Client) cleanup() {
    c.once.Do(func() {
        log.Printf("cleaning up client %s", c.id)
        c.cancel()
        c.conn.Close()
        c.server.removeClient(c.id)
    })
}

// ------------------------------------------------------

// TcpServer is a TCP server that accepts connections and tracks the live
// clients so they can all be disconnected on shutdown.
type TcpServer struct {
    // listener is set by Start.
    listener net.Listener

    // mu guards clients, which maps each client's id to the client.
    mu      sync.Mutex
    clients map[string]*Client

    // wg counts the accept loop plus one goroutine per client; Stop waits on it.
    wg sync.WaitGroup

    // ctx is cancelled by Stop so acceptLoop can tell shutdown from errors.
    ctx    context.Context
    cancel context.CancelFunc
}

// NewTcpServer returns a server with an empty client table and a fresh
// cancellable context. Call Start to begin listening.
func NewTcpServer() *TcpServer {
    s := &TcpServer{clients: map[string]*Client{}}
    s.ctx, s.cancel = context.WithCancel(context.Background())
    return s
}

// Start binds a TCP listener on address and launches the accept loop in the
// background. It returns net.Listen's error unchanged if binding fails.
func (s *TcpServer) Start(address string) error {
    ln, err := net.Listen("tcp", address)
    s.listener = ln
    if err != nil {
        return err
    }

    log.Printf("server started on %s", address)
    s.wg.Add(1)
    go s.acceptLoop()
    return nil
}

// acceptLoop accepts connections until shutdown. It registers each connection
// as a Client and runs it in a goroutine tracked by s.wg. It returns once the
// server context is cancelled or the listener has been closed.
//
// Other Accept errors (for example running out of file descriptors) are
// retried with a capped exponential backoff instead of a tight busy loop.
func (s *TcpServer) acceptLoop() {
    defer s.wg.Done()

    var backoff time.Duration
    for {
        conn, err := s.listener.Accept()
        if err != nil {
            // A closed listener never recovers. Without the ErrClosed check, a
            // listener closed outside Stop would make this loop spin forever.
            if s.ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
                log.Println("listener closed, stopping accept loop")
                return
            }
            if backoff == 0 {
                backoff = 5 * time.Millisecond
            } else {
                backoff *= 2
            }
            if backoff > time.Second {
                backoff = time.Second
            }
            log.Printf("accept error: %v", err)
            select {
            case <-s.ctx.Done():
                log.Println("listener closed, stopping accept loop")
                return
            case <-time.After(backoff):
            }
            continue
        }
        backoff = 0

        s.mu.Lock()
        // Stop cancels s.ctx before it snapshots s.clients under s.mu. If the
        // context is already cancelled here, this client would miss that
        // snapshot, never be cleaned up, and Stop would block waiting on it.
        if s.ctx.Err() != nil {
            s.mu.Unlock()
            conn.Close()
            return
        }
        client := NewClient(conn, s)
        s.clients[client.id] = client
        s.mu.Unlock()

        s.wg.Add(1)
        go func() {
            defer s.wg.Done()
            client.Start()
            <-client.ctx.Done()
            log.Printf("client %s goroutine exit", client.id)
        }()
    }
}

// removeClient deletes the client registered under id (a no-op if none is
// registered) and logs the removal. It acquires s.mu, so it must not be called
// while s.mu is already held.
func (s *TcpServer) removeClient(id string) {
    s.mu.Lock()
    delete(s.clients, id)
    log.Printf("client %s removed from server", id)
    s.mu.Unlock()
}

// Stop shuts the server down gracefully. It stops accepting connections,
// disconnects every client, and waits for the accept loop and all per-client
// goroutines tracked by s.wg to exit.
func (s *TcpServer) Stop() {
    log.Println("stopping server...")
    // Cancel before closing the listener so acceptLoop treats the resulting
    // Accept error as shutdown rather than a failure.
    s.cancel()
    if s.listener != nil {
        s.listener.Close()
    }

    // Snapshot the clients under the lock, then clean them up after releasing
    // it. cleanup calls removeClient, which locks s.mu again, and sync.Mutex is
    // not reentrant. Calling cleanup while holding s.mu deadlocks.
    s.mu.Lock()
    clients := make([]*Client, 0, len(s.clients))
    for _, client := range s.clients {
        clients = append(clients, client)
    }
    s.mu.Unlock()

    for _, client := range clients {
        client.cleanup()
    }

    s.wg.Wait()
    log.Println("server stopped gracefully")
}

// ------------------------------------------------------

// main starts the server on :9999, blocks until SIGINT or SIGTERM arrives, and
// then shuts the server down gracefully.
func main() {
    server := NewTcpServer()
    if err := server.Start(":9999"); err != nil {
        log.Fatalf("failed to start server: %v", err)
    }

    // Deliver SIGINT/SIGTERM to stopChan instead of terminating the process,
    // so a graceful shutdown can run.
    stopChan := make(chan os.Signal, 1)
    signal.Notify(stopChan, syscall.SIGINT, syscall.SIGTERM)

    <-stopChan
    // Restore default signal handling. If graceful shutdown gets stuck, a
    // second Ctrl-C then terminates the process instead of being swallowed.
    signal.Stop(stopChan)
    log.Println("received interrupt signal, shutting down...")
    server.Stop()
}

本文地址:https://www.blear.cn/article/golang-tcp-server

转载时请以链接形式注明出处

评论
应监管部门要求，个人网站不允许开放评论功能，评论已关闭，抱歉！