zRPC提供自定义拦截器功能以满足不同的业务需求

客户端拦截器定义

客户端拦截器需要实现 grpc.UnaryClientInterceptor,定义如下:

type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error

在创建客户端的时候通过zrpc.WithUnaryClientInterceptor进行注册

服务端拦截器定义

服务端拦截器需要实现grpc.UnaryServerInterceptor,定义如下:

type UnaryServerInterceptor func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error)

通过RpcServer.AddUnaryInterceptors进行注册

自定义拦截器示例

自定义客户端和服务端拦截器,客户端拦截器输出请求方法耗时,服务端拦截器进行简单的限流

客户端代码

package main

import (
    "context"
    "fmt"
    "log"
    "time"

    "test/zrpc/pb"

    "github.com/tal-tech/go-zero/core/discov"
    "github.com/tal-tech/go-zero/zrpc"
    "google.golang.org/grpc"
)

func timeInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    stime := time.Now()
    err := invoker(ctx, method, req, reply, cc, opts...)
    if err != nil {
        return err
    }

    fmt.Printf("调用 %s 方法 耗时: %v\n", method, time.Now().Sub(stime))
    return nil
}

func main() {
    client := zrpc.MustNewClient(zrpc.RpcClientConf{
        Etcd: discov.EtcdConf{
            Hosts: []string{"127.0.0.1:2379"},
            Key:   "hello.rpc",
        },
    }, zrpc.WithUnaryClientInterceptor(timeInterceptor))

    hello := pb.NewGreeterClient(client.Conn())

    var count int
    for {
        reply, err := hello.SayHello(context.Background(), &pb.HelloRequest{Name: fmt.Sprintf("Hanmeimei%d", count)})
        if err != nil {
            log.Fatal(err)
        }
        count++
        log.Println(reply.Message)
        time.Sleep(time.Millisecond * 100)
    }
}

服务端代码

Name: hello.rpc
Log:
  Mode: console
ListenOn: 127.0.0.1:9090
Etcd:
  Hosts:
    - 127.0.0.1:2379
  Key: hello.rpc


package main

import (
    "context"
    "flag"
    "fmt"
    "log"
    "time"

    "test/zrpc/pb"

    "github.com/tal-tech/go-zero/core/conf"
    "github.com/tal-tech/go-zero/zrpc"
    "golang.org/x/time/rate"
    "google.golang.org/grpc"
)

type Config struct {
    zrpc.RpcServerConf
}

var cfgFile = flag.String("f", "./server.yaml", "cfg file")

var limiter = rate.NewLimiter(rate.Limit(100), 100)

func main() {
    flag.Parse()

    var cfg Config
    conf.MustLoad(*cfgFile, &cfg)

    srv := zrpc.MustNewServer(cfg.RpcServerConf, func(s *grpc.Server) {
        pb.RegisterGreeterServer(s, &Hello{})
    })
    srv.AddUnaryInterceptors(rateLimitInterceptor)
    srv.Start()
}

type Hello struct{}

func (h *Hello) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
    time.Sleep(time.Millisecond * 300)
    return &pb.HelloReply{Message: "hello " + in.Name}, nil
}

func rateLimitInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
    if !limiter.Allow() {
        fmt.Println("限流了")
        return nil, nil
    }
    return handler(ctx, req)
}

运行代码可以看到客户端输出如下信息:

调用 /pb.Greeter/SayHello 方法 耗时: 300.633257ms

当请求速率加快服务端会输出:

限流了

以上输出说明自定义的拦截器已经可以正常的工作了