前置中间件

// PrepareMiddleware is the entry middleware: it makes sure the request
// context carries a trace ID before the rest of the chain runs.
//
// The trace ID is taken from the first value stored under util.TraceID in the
// incoming gRPC metadata. When the metadata is absent or holds no value for
// that key, a new ID is generated with logs.GenTraceId.
func PrepareMiddleware(next MiddlewareFunc) MiddlewareFunc {
    return func(ctx context.Context, req interface{}) (interface{}, error) {
        traceID := ""
        if md, found := metadata.FromIncomingContext(ctx); found {
            // Indexing a missing key yields a nil slice, so one length check
            // covers both "key absent" and "key present but empty".
            if ids := md[util.TraceID]; len(ids) > 0 {
                traceID = ids[0]
            }
        }
        if traceID == "" {
            traceID = logs.GenTraceId()
        }

        // Attach the logging field container first (per the original note it
        // reuses an existing one or creates it), then store the trace ID.
        ctx = logs.WithTraceId(logs.WithFieldContext(ctx), traceID)
        return next(ctx, req)
    }
}

限流中间件

计数限流

// CounterLimit is a fixed-window counter rate limiter: at most limit calls to
// Allow succeed within each window of intervalNano nanoseconds.
//
// All mutable fields are accessed through sync/atomic so Allow can be called
// from several goroutines. The window rollover is best-effort: a request that
// races with a rollover may be attributed to either window.
type CounterLimit struct {
    counter      int64 // requests counted in the current window
    limit        int64 // maximum number of requests allowed per window
    intervalNano int64 // window length, in nanoseconds
    unixNano     int64 // start of the current window, Unix time in nanoseconds
}

// NewCounterLimit returns a limiter that admits up to limit requests per
// interval. The first window starts when the limiter is created.
func NewCounterLimit(interval time.Duration, limit int64) *CounterLimit {
    return &CounterLimit{
        counter:      0,
        limit:        limit,
        intervalNano: int64(interval),
        unixNano:     time.Now().UnixNano(),
    }
}

// Allow reports whether the current request fits in the current window.
//
// Every call is counted, including the one that opens a new window, so each
// window admits exactly limit requests and a non-positive limit admits none.
func (c *CounterLimit) Allow() bool {
    now := time.Now().UnixNano()
    windowStart := atomic.LoadInt64(&c.unixNano)
    if now-windowStart > c.intervalNano {
        // The window has expired. The compare-and-swap lets only one
        // goroutine perform the rollover; losers simply count themselves in
        // the new window.
        if atomic.CompareAndSwapInt64(&c.unixNano, windowStart, now) {
            atomic.StoreInt64(&c.counter, 0)
        }
    }
    // Use the value returned by AddInt64 rather than re-reading the field, so
    // the decision is based on this request's own position in the window.
    return atomic.AddInt64(&c.counter, 1) <= c.limit
}

桶限流

// BucketLimit is a leaky-bucket rate limiter. Each admitted request adds one
// unit of water; water drains at rate units per second; a request is rejected
// when adding its unit would overflow bucketSize.
//
// The type performs no synchronization of its own.
type BucketLimit struct {
    rate       float64 // drain rate, in units of water per second
    bucketSize float64 // maximum amount of water the bucket can hold
    unixNano   int64   // time of the last refresh, Unix time in nanoseconds
    curWater   float64 // water currently in the bucket
}

// NewBucketLimit returns an empty bucket that drains at rate units per second
// and holds at most bucketSize units.
func NewBucketLimit(rate float64, bucketSize int64) *BucketLimit {
    return &BucketLimit{
        bucketSize: float64(bucketSize),
        rate:       rate,
        unixNano:   time.Now().UnixNano(),
        curWater:   0,
    }
}

// reflesh drains the water that leaked out since the previous refresh and
// moves the refresh timestamp to now.
func (b *BucketLimit) reflesh() {
    now := time.Now().UnixNano()
    // UnixNano is wall-clock time; if the clock stepped backwards the elapsed
    // time is negative and must not be turned into extra water.
    if elapsed := now - b.unixNano; elapsed > 0 {
        leaked := float64(elapsed) / float64(time.Second) * b.rate
        b.curWater = math.Max(0, b.curWater-leaked)
    }
    b.unixNano = now
}

// Allow reports whether one more request fits in the bucket, and if so adds
// its unit of water.
func (b *BucketLimit) Allow() bool {
    b.reflesh()
    // Check the level after adding this request's unit: with a fractional
    // level (e.g. 1.5 of 2) a plain "curWater < bucketSize" test would let
    // the bucket overflow.
    if b.curWater+1 > b.bucketSize {
        return false
    }
    b.curWater++
    return true
}

令牌桶限流

  • 基于golang.org/x/time/rate包
import "golang.org/x/time/rate"
var limiter = rate.NewLimiter(50,100) // 50 is the rate limit in QPS; 100 is the bucket (burst) size

中间件实现

// Limiter is the minimal contract the rate-limit middleware needs from a
// rate limiter implementation.
type Limiter interface {
    // Allow reports whether the current request may proceed.
    Allow() bool
}
// NewRateLimitMiddleware builds a middleware that consults l before every
// request. Requests that l rejects are not passed on; they fail with a gRPC
// ResourceExhausted status and a nil response.
func NewRateLimitMiddleware(l Limiter) Middleware {
    return func(next MiddlewareFunc) MiddlewareFunc {
        guarded := func(ctx context.Context, req interface{}) (interface{}, error) {
            if l.Allow() {
                return next(ctx, req)
            }
            return nil, status.Error(codes.ResourceExhausted, "rate limited")
        }
        return guarded
    }
}

熔断器中间件

配置

  • 基于github.com/afex/hystrix-go/hystrix包
// Circuit-breaker settings for one command (keyed by service name).
// The original snippet had the values of RequestVolumeThreshold and
// SleepWindow swapped relative to their comments; they are paired correctly
// here. Defaults quoted below are hystrix-go's documented defaults.
hystrix.ConfigureCommand("服务名", hystrix.CommandConfig{
    Timeout:                1000, // how long to wait for the command, in ms (default 1000)
    MaxConcurrentRequests:  100,  // how many commands may run concurrently (default 10)
    RequestVolumeThreshold: 30,   // minimum requests in a statistics window before the circuit can trip (default 20)
    SleepWindow:            500,  // how long, in ms, to stay open before probing for recovery (default 5000)
    ErrorPercentThreshold:  50,   // error percentage at or above which the circuit opens (default 50)
})

触发条件

  • 一个统计窗口内,请求数量大于RequestVolumeThreshold,且失败率大于ErrorPercentThreshold, 才会触发熔断

中间件实现

// HystrixMiddleware runs the rest of the chain inside a hystrix command named
// after the target service (rpcMeta.ServiceName), with no fallback.
//
// It returns the downstream response on success. On any error, whether from
// the downstream call or from hystrix itself (circuit open, concurrency limit,
// timeout), it returns a nil response together with that error.
func HystrixMiddleware(next MiddlewareFunc) MiddlewareFunc {
    return func(ctx context.Context, req interface{}) (interface{}, error) {
        rpcMeta := meta.GetRpcMeta(ctx)

        // hystrix executes the function on its own goroutine and may give up
        // (e.g. on timeout) while that goroutine is still running. The result
        // is therefore kept in a local that is read only on the success path,
        // instead of writing straight into the function's return value.
        var result interface{}
        err := hystrix.Do(rpcMeta.ServiceName, func() error {
            r, callErr := next(ctx, req)
            if callErr != nil {
                return callErr
            }
            result = r
            return nil
        }, nil)
        if err != nil {
            return nil, err
        }
        return result, nil
    }
}

服务注册与发现中间件

配置

https://www.cnblogs.com/zhaohaiyu/p/13566315.html

中间件实现

// NewDiscoveryMiddleware builds a middleware that fills rpcMeta.AllNodes from
// the given registry before the rest of the chain runs.
//
// If the RPC metadata already lists nodes, the registry is not consulted.
// A lookup failure is logged and returned with a nil response; the rest of
// the chain is not invoked in that case.
func NewDiscoveryMiddleware(discovery registry.Registry) Middleware {
    return func(next MiddlewareFunc) MiddlewareFunc {
        return func(ctx context.Context, req interface{}) (interface{}, error) {
            rpcMeta := meta.GetRpcMeta(ctx)
            if len(rpcMeta.AllNodes) == 0 {
                service, lookupErr := discovery.GetService(ctx, rpcMeta.ServiceName)
                if lookupErr != nil {
                    logs.Error(ctx, "discovery service:%s failed, err:%v", rpcMeta.ServiceName, lookupErr)
                    return nil, lookupErr
                }
                rpcMeta.AllNodes = service.Nodes
            }
            return next(ctx, req)
        }
    }
}

短连接中间件

配置

// RpcMeta holds the per-call metadata of an outgoing RPC: who is calling
// whom, the candidate and chosen nodes, and the connection in use.
type RpcMeta struct {
    // Name of the caller.
    Caller string
    // Name of the service being called (the provider).
    ServiceName string
    // Method being invoked.
    Method string
    // Cluster of the caller.
    CallerCluster string
    // Cluster of the service provider.
    ServiceCluster string
    // Trace ID.
    TraceID string
    // Environment.
    Env string
    // IDC of the caller.
    CallerIDC string
    // IDC of the service provider.
    ServiceIDC string
    // Node currently selected for the call.
    CurNode *registry.Node
    // Nodes that were selected earlier for this call.
    HistoryNodes []*registry.Node
    // Full node list of the service provider.
    AllNodes []*registry.Node
    // Connection used by the current request.
    Conn *grpc.ClientConn
}
// rpcMetaContextKey is the private context key under which *RpcMeta is stored.
type rpcMetaContextKey struct{}

// GetRpcMeta returns the *RpcMeta stored in ctx. When ctx holds no value of
// that type under the key, it returns a freshly allocated empty RpcMeta; that
// new value is not attached to ctx.
func GetRpcMeta(ctx context.Context) *RpcMeta {
    if stored, ok := ctx.Value(rpcMetaContextKey{}).(*RpcMeta); ok {
        return stored
    }
    return &RpcMeta{}
}

中间件实现

// ShortConnectMiddleware dials a new gRPC connection to the currently
// selected node for every request, and closes it when the rest of the chain
// returns.
//
// It fails with errno.InvalidNode when no node has been selected, and with
// errno.ConnFailed when dialing fails (the dial error itself is only logged).
// The connection is published to the rest of the chain via rpcMeta.Conn.
func ShortConnectMiddleware(next MiddlewareFunc) MiddlewareFunc {
    return func(ctx context.Context, req interface{}) (interface{}, error) {
        rpcMeta := meta.GetRpcMeta(ctx)
        target := rpcMeta.CurNode
        if target == nil {
            logs.Error(ctx, "invalid instance")
            return nil, errno.InvalidNode
        }

        address := fmt.Sprintf("%s:%d", target.IP, target.Port)
        conn, dialErr := grpc.Dial(address, grpc.WithInsecure())
        if dialErr != nil {
            logs.Error(ctx, "connect %s failed, err:%v", address, dialErr)
            return nil, errno.ConnFailed
        }
        // Short-lived connection: released once the downstream call is done.
        defer conn.Close()

        rpcMeta.Conn = conn
        return next(ctx, req)
    }
}

负载均衡中间件

配置

选择

// DefaultNodeWeight is the weight given to a node whose weight is zero.
const DefaultNodeWeight = 100

// LoadBalanceType identifies a load-balancing strategy.
type LoadBalanceType int

// Supported strategies. The constants are declared without an explicit type,
// so they are untyped integer constants (0 and 1) that convert implicitly to
// LoadBalanceType where one is expected.
const (
    // LoadBalanceTypeRandom selects the random strategy.
    LoadBalanceTypeRandom = iota
    // LoadBalanceTypeRoundRobin selects the round-robin strategy.
    LoadBalanceTypeRoundRobin
)
// LoadBalance is a node-selection strategy.
type LoadBalance interface {
    // Name returns the strategy's name.
    Name() string
    // Select picks one node out of nodes for the call described by ctx, or
    // returns an error when no node can be chosen.
    Select(ctx context.Context, nodes []*registry.Node) (*registry.Node, error)
}
// GetLoadBalance returns a new balancer for balanceType: round-robin for
// LoadBalanceTypeRoundRobin, and random for LoadBalanceTypeRandom as well as
// for any unrecognized value.
func GetLoadBalance(balanceType LoadBalanceType) (balancer LoadBalance) {
    if balanceType == LoadBalanceTypeRoundRobin {
        return NewRoundRobinBalance()
    }
    // Random is both an explicit choice and the fallback.
    return NewRandomBalance()
}

randombalance

// RandomBalance is a weighted-random node-selection strategy.
type RandomBalance struct {
    name string
}

// NewRandomBalance returns a weighted-random balancer named "random".
func NewRandomBalance() LoadBalance {
    return &RandomBalance{
        name: "random",
    }
}

// Name returns the strategy's name.
func (r *RandomBalance) Name() string {
    return r.name
}

// Select picks a node at random, with probability proportional to its weight,
// from the nodes that filterNodes keeps for this ctx, and records the choice
// with setSelected.
//
// Nodes whose Weight is zero are assigned DefaultNodeWeight (the node itself
// is updated). It returns errno.NotHaveInstance when nodes is empty and
// errno.AllNodeFailed when no candidate is left after filtering or the total
// weight is not positive.
func (r *RandomBalance) Select(ctx context.Context, nodes []*registry.Node) (node *registry.Node, err error) {
    if len(nodes) == 0 {
        return nil, errno.NotHaveInstance
    }
    candidates := filterNodes(ctx, nodes)
    if len(candidates) == 0 {
        return nil, errno.AllNodeFailed
    }

    var totalWeight int
    for _, c := range candidates {
        if c.Weight == 0 {
            c.Weight = DefaultNodeWeight
        }
        totalWeight += c.Weight
    }
    // rand.Intn panics on a non-positive argument.
    if totalWeight <= 0 {
        return nil, errno.AllNodeFailed
    }

    // Walk the same list the total was computed from. Walking the unfiltered
    // list would both skew the distribution and allow a filtered-out node to
    // be returned.
    remaining := rand.Intn(totalWeight)
    for _, c := range candidates {
        remaining -= c.Weight
        if remaining < 0 {
            setSelected(ctx, c)
            return c, nil
        }
    }
    return nil, errno.AllNodeFailed
}

roundrobin

// RoundRobinBalance is a round-robin node-selection strategy.
//
// NOTE(review): index is read and written without synchronization; confirm
// whether one balancer instance is shared between concurrent requests.
type RoundRobinBalance struct {
    name  string
    index int // position of the most recently returned candidate
}

// NewRoundRobinBalance returns a round-robin balancer named "roundrobin".
func NewRoundRobinBalance() LoadBalance {
    return &RoundRobinBalance{
        name: "roundrobin",
    }
}

// Name returns the strategy's name.
func (r *RoundRobinBalance) Name() string {
    return r.name
}

// Select advances the rotation by one and returns the corresponding node from
// the nodes that filterNodes keeps for this ctx, recording the choice with
// setSelected.
//
// It returns errno.NotHaveInstance when nodes is empty and
// errno.AllNodeFailed when no candidate is left after filtering.
func (r *RoundRobinBalance) Select(ctx context.Context, nodes []*registry.Node) (node *registry.Node, err error) {
    if len(nodes) == 0 {
        return nil, errno.NotHaveInstance
    }
    candidates := filterNodes(ctx, nodes)
    if len(candidates) == 0 {
        return nil, errno.AllNodeFailed
    }

    // Rotate over the filtered list; indexing the unfiltered list would hand
    // back nodes that were filtered out.
    r.index = (r.index + 1) % len(candidates)
    node = candidates[r.index]
    if node != nil {
        setSelected(ctx, node)
    }
    return node, nil
}

中间件实现

// NewLoadBalanceMiddleware builds a middleware that selects a target node
// with balancer and retries on another node when the call fails with a
// connection error.
//
// It fails with errno.NotHaveInstance when rpcMeta.AllNodes is empty. Every
// selected node is stored in rpcMeta.CurNode and appended to
// rpcMeta.HistoryNodes. The loop ends when the call succeeds, when it fails
// with a non-connection error, or when balancer.Select returns an error.
func NewLoadBalanceMiddleware(balancer loadbalance.LoadBalance) Middleware {
    return func(next MiddlewareFunc) MiddlewareFunc {
        return func(ctx context.Context, req interface{}) (resp interface{}, err error) {
            rpcMeta := meta.GetRpcMeta(ctx)
            if len(rpcMeta.AllNodes) == 0 {
                err = errno.NotHaveInstance
                logs.Error(ctx, "not have instance")
                return
            }

            // Balancer context: used to filter out nodes already selected.
            ctx = loadbalance.WithBalanceContext(ctx)
            for {
                if rpcMeta.CurNode, err = balancer.Select(ctx, rpcMeta.AllNodes); err != nil {
                    return
                }
                logs.Debug(ctx, "select node:%#v", rpcMeta.CurNode)
                rpcMeta.HistoryNodes = append(rpcMeta.HistoryNodes, rpcMeta.CurNode)

                resp, err = next(ctx, req)
                // Only connection errors are retried; success and every
                // other error end the loop.
                if err == nil || !errno.IsConnectError(err) {
                    return
                }
            }
        }
    }
}

监测中间件

  • 基于prometheus

配置

// Metrics groups the Prometheus collectors used for sampling: a request
// counter, a per-status-code counter and a latency summary.
// (Original note: "server-side sampling metrics".)
type Metrics struct {
    requestCounter *prom.CounterVec
    codeCounter    *prom.CounterVec
    latencySummary *prom.SummaryVec
}
// NewServerMetrics creates the server-side Metrics instance. The three
// collectors are created through promauto, in the order request counter,
// code counter, latency summary.
func NewServerMetrics() *Metrics {
    perMethod := []string{"service", "method"}
    perCode := []string{"service", "method", "grpc_code"}

    requests := promauto.NewCounterVec(prom.CounterOpts{
        Name: "zhy_server_request_total",
        Help: "Total number of RPCs completed on the server, regardless of success or failure.",
    }, perMethod)

    codes := promauto.NewCounterVec(prom.CounterOpts{
        Name: "zhy_server_handled_code_total",
        Help: "Total number of RPCs completed on the server, regardless of success or failure.",
    }, perCode)

    latency := promauto.NewSummaryVec(prom.SummaryOpts{
        Name:       "zhy_proc_cost",
        Help:       "RPC latency distributions.",
        Objectives: map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.99: 0.001},
    }, perMethod)

    return &Metrics{
        requestCounter: requests,
        codeCounter:    codes,
        latencySummary: latency,
    }
}
// NewRpcMetrics creates the client-side (outgoing RPC) Metrics instance.
// The Help texts describe client calls; the original ones were copied from
// the server metrics and spoke of RPCs "completed on the server".
func NewRpcMetrics() *Metrics {
    return &Metrics{
        requestCounter: promauto.NewCounterVec(
            prom.CounterOpts{
                Name: "zhy_rpc_call_total",
                Help: "Total number of RPCs issued by the client.",
            }, []string{"service", "method"}),
        codeCounter: promauto.NewCounterVec(
            prom.CounterOpts{
                Name: "zhy_rpc_code_total",
                Help: "Total number of RPCs completed by the client, by gRPC status code.",
            }, []string{"service", "method", "grpc_code"}),
        latencySummary: promauto.NewSummaryVec(
            prom.SummaryOpts{
                Name:       "zhy_rpc_cost",
                Help:       "RPC latency distributions.",
                Objectives: map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.99: 0.001},
            },
            []string{"service", "method"},
        ),
    }
}
// IncrRequest adds one to the request counter for the given service and
// method. ctx is not used.
func (m *Metrics) IncrRequest(ctx context.Context, serviceName, methodName string) {
    counter := m.requestCounter.WithLabelValues(serviceName, methodName)
    counter.Inc()
}
// IncrCode adds one to the code counter for the given service and method,
// labelled with the gRPC status code that status.FromError derives from err.
// ctx is not used.
func (m *Metrics) IncrCode(ctx context.Context, serviceName, methodName string, err error) {
    // The "ok" result is deliberately ignored: only the code is needed.
    st, _ := status.FromError(err)
    grpcCode := st.Code().String()
    m.codeCounter.WithLabelValues(serviceName, methodName, grpcCode).Inc()
}
// Latency records one latency observation for the given service and method.
// The value us is observed as-is, converted to float64. ctx is not used.
func (m *Metrics) Latency(ctx context.Context, serviceName, methodName string, us int64) {
    observer := m.latencySummary.WithLabelValues(serviceName, methodName)
    observer.Observe(float64(us))
}

中间件实现

// DefaultServerMetrics is the package-wide Metrics instance used by the
// server-side middleware.
var DefaultServerMetrics = prometheus.NewServerMetrics()

// DefaultRpcMetrics is the package-wide Metrics instance used by the
// client-side middleware.
var DefaultRpcMetrics = prometheus.NewRpcMetrics()
// init exposes the Prometheus metrics over HTTP: it registers promhttp's
// handler at /metrics on the default mux and serves it on 0.0.0.0:8888 from a
// background goroutine.
func init() {
    go func() {
        http.Handle("/metrics", promhttp.Handler())
        addr := fmt.Sprintf("0.0.0.0:%d", 8888)
        // ListenAndServe only returns on failure (e.g. the port is taken).
        // Report it instead of dropping it, so a missing metrics endpoint is
        // visible.
        if err := http.ListenAndServe(addr, nil); err != nil {
            fmt.Printf("ERROR: metrics endpoint %s stopped: %v\n", addr, err)
        }
    }()
}
// 服务端中间件
func PrometheusServerMiddleware(next MiddlewareFunc) MiddlewareFunc {
    return func(ctx context.Context, req interface{}) (resp interface{}, err error) {
        serverMeta := meta.GetServerMeta(ctx)
        DefaultServerMetrics.IncrRequest(ctx, serverMeta.ServiceName, serverMeta.Method)
        startTime := time.Now()
        resp, err = next(ctx, req)
        DefaultServerMetrics.IncrCode(ctx, serverMeta.ServiceName, serverMeta.Method, err)
        DefaultServerMetrics.Latency(ctx, serverMeta.ServiceName,
            serverMeta.Method, time.S***artTime).Nanoseconds()/1000)
        return
    }
}
// 客户端的中间件
func PrometheusClientMiddleware(next MiddlewareFunc) MiddlewareFunc {
    return func(ctx context.Context, req interface{}) (resp interface{}, err error) {
        rpcMeta := meta.GetRpcMeta(ctx)
        DefaultRpcMetrics.IncrRequest(ctx, rpcMeta.ServiceName, rpcMeta.Method)
        startTime := time.Now()
        resp, err = next(ctx, req)
        DefaultRpcMetrics.IncrCode(ctx, rpcMeta.ServiceName, rpcMeta.Method, err)
        DefaultRpcMetrics.Latency(ctx, rpcMeta.ServiceName,
            rpcMeta.Method, time.S***artTime).Nanoseconds()/1000)
        return
    }
}

分布式追踪中间件

配置

// binHdrSuffix is the key suffix that marks a metadata entry whose value is
// base64-encoded before transmission.
const binHdrSuffix = "-bin"
// metadataTextMap adapts a gRPC metadata.MD so it can be used as an
// opentracing text-map carrier.
type metadataTextMap metadata.MD

// Set stores val under key, implementing the writer side of the opentracing
// text-map carrier. The pair is first passed through encodeKeyValue.
func (m metadataTextMap) Set(key, val string) {
    k, v := encodeKeyValue(key, val)
    // metadata.MD is a multimap, but tracing headers are single-valued:
    // replace whatever was there rather than appending.
    m[k] = []string{v}
}
// ForeachKey calls callback once for every key/value pair in the metadata,
// implementing the reader side of the opentracing text-map carrier. Each pair
// is decoded with metadata.DecodeKeyValue first.
//
// Iteration stops at the first failure: a decoding failure is returned wrapped
// in a descriptive error, a callback error is returned unchanged.
func (m metadataTextMap) ForeachKey(callback func(key, val string) error) error {
    for rawKey, values := range m {
        for _, rawVal := range values {
            key, val, decodeErr := metadata.DecodeKeyValue(rawKey, rawVal)
            if decodeErr != nil {
                return fmt.Errorf("failed decoding opentracing from gRPC metadata: %v", decodeErr)
            }
            if cbErr := callback(key, val); cbErr != nil {
                return cbErr
            }
        }
    }
    return nil
}
// encodeKeyValue prepares a key/value pair for transmission as gRPC metadata:
// the key is lower-cased, and when it ends in binHdrSuffix the value is
// base64-encoded (standard encoding). Other values are returned unchanged.
// Per the original note, this mirrors unexported logic in grpc's metadata
// package.
func encodeKeyValue(k, v string) (string, string) {
    key := strings.ToLower(k)
    if !strings.HasSuffix(key, binHdrSuffix) {
        return key, v
    }
    return key, base64.StdEncoding.EncodeToString([]byte(v))
}
// InitTrace builds a Jaeger tracer for serviceName and installs it as the
// opentracing global tracer.
//
// Spans are reported over HTTP to reportAddr in batches of 16. sampleType and
// rate configure the sampler's Type and Param. On failure the error is
// printed to stdout and returned, and the global tracer is left untouched.
//
// NOTE(review): the closer returned by the Jaeger configuration is discarded,
// so nothing in this function ever closes the tracer; confirm whether a
// shutdown hook is needed.
func InitTrace(serviceName, reportAddr, sampleType string, rate float64) (err error) {
    sender := transport.NewHTTPTransport(reportAddr, transport.HTTPBatchSize(16))
    reporter := jaeger.NewRemoteReporter(sender)

    sampler := &config.SamplerConfig{
        Type:  sampleType,
        Param: rate,
    }
    cfg := &config.Configuration{
        Sampler:  sampler,
        Reporter: &config.ReporterConfig{LogSpans: true},
    }

    tracer, closer, err := cfg.New(
        serviceName,
        config.Logger(jaeger.StdLogger),
        config.Reporter(reporter),
    )
    if err != nil {
        fmt.Printf("ERROR: cannot init Jaeger: %v\n", err)
        return err
    }
    _ = closer

    opentracing.SetGlobalTracer(tracer)
    return nil
}

中间件实现

// TraceServerMiddleware is the server-side tracing middleware. It extracts
// the caller's span context from the incoming gRPC metadata, starts a server
// span named after the method being served, tags it with the trace ID, and
// makes it available to the rest of the chain through ctx.
//
// An extraction failure other than "span context not found" is logged as a
// warning and does not fail the request. If the rest of the chain returns an
// error, the span is flagged as failed and the message is logged on it. The
// span is finished before returning.
func TraceServerMiddleware(next MiddlewareFunc) MiddlewareFunc {
    return func(ctx context.Context, req interface{}) (resp interface{}, err error) {
        md, found := metadata.FromIncomingContext(ctx)
        if !found {
            // No incoming metadata: extract from an empty carrier.
            md = metadata.Pairs()
        }

        tracer := opentracing.GlobalTracer()
        parentSpanContext, extractErr := tracer.Extract(opentracing.HTTPHeaders, metadataTextMap(md))
        if extractErr != nil && extractErr != opentracing.ErrSpanContextNotFound {
            logs.Warn(ctx, "trace extract failed, parsing trace information: %v", extractErr)
        }

        serverMeta := meta.GetServerMeta(ctx)
        span := tracer.StartSpan(
            serverMeta.Method,
            ext.RPCServerOption(parentSpanContext),
            ext.SpanKindRPCServer,
        )
        span.SetTag(util.TraceID, logs.GetTraceId(ctx))

        resp, err = next(opentracing.ContextWithSpan(ctx, span), req)
        if err != nil {
            ext.Error.Set(span, true)
            span.LogFields(log.String("event", "error"), log.String("message", err.Error()))
        }
        span.Finish()
        return resp, err
    }
}
// TraceClientMiddleware is the client-side tracing middleware. It starts a
// client span named after the target service (a child of the span found in
// ctx, if any), injects the span context and the trace ID into the outgoing
// gRPC metadata, and makes the span available to the rest of the chain
// through ctx.
//
// An injection failure is logged at debug level and does not fail the call.
// If the rest of the chain returns an error, the span is flagged as failed
// and the message is logged on it. The span is finished before returning.
func TraceClientMiddleware(next MiddlewareFunc) MiddlewareFunc {
    return func(ctx context.Context, req interface{}) (resp interface{}, err error) {
        tracer := opentracing.GlobalTracer()
        var parentSpanCtx opentracing.SpanContext
        if parent := opentracing.SpanFromContext(ctx); parent != nil {
            parentSpanCtx = parent.Context()
        }
        opts := []opentracing.StartSpanOption{
            opentracing.ChildOf(parentSpanCtx),
            ext.SpanKindRPCClient,
            opentracing.Tag{Key: string(ext.Component), Value: "koala_rpc"},
            opentracing.Tag{Key: util.TraceID, Value: logs.GetTraceId(ctx)},
        }
        rpcMeta := meta.GetRpcMeta(ctx)
        clientSpan := tracer.StartSpan(rpcMeta.ServiceName, opts...)

        md, ok := metadata.FromOutgoingContext(ctx)
        if !ok {
            md = metadata.Pairs()
        } else {
            // Inject writes into the carrier. Work on a copy so metadata
            // that may still be referenced by the parent context is never
            // modified in place.
            md = md.Copy()
        }
        if injectErr := tracer.Inject(clientSpan.Context(), opentracing.HTTPHeaders, metadataTextMap(md)); injectErr != nil {
            logs.Debug(ctx, "grpc_opentracing: failed serializing trace information: %v", injectErr)
        }
        ctx = metadata.NewOutgoingContext(ctx, md)
        ctx = metadata.AppendToOutgoingContext(ctx, util.TraceID, logs.GetTraceId(ctx))
        ctx = opentracing.ContextWithSpan(ctx, clientSpan)

        resp, err = next(ctx, req)
        // Record the failure on the span.
        if err != nil {
            ext.Error.Set(clientSpan, true)
            clientSpan.LogFields(log.String("event", "error"), log.String("message", err.Error()))
        }
        clientSpan.Finish()
        return
    }
}