Spring Boot 集成 Netty 实现代理服务器
说明
使用 Netty 实现代理服务，思路是：浏览器（客户端）发送请求，Netty 服务端在监听端口上接收到请求后，在内部再启动一个 Netty 客户端作为代理去访问真实的目标服务器；目标服务器把响应返回给代理客户端，代理客户端再交给 Netty 服务端，最后由服务端返回给浏览器。
目前实现了 HTTP 和 HTTPS 的代理。其中 HTTPS 通过 CONNECT 隧道实现：服务端先对浏览器的 CONNECT 请求回复 “200 Connection established”，之后只原样转发加密数据，不做解密。
导入依赖
<dependencies>
    <!-- All Netty modules share one version; ${netty.version} must be defined in the POM's <properties>. -->
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-buffer</artifactId>
        <version>${netty.version}</version>
    </dependency>
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-codec</artifactId>
        <version>${netty.version}</version>
    </dependency>
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-codec-http</artifactId>
        <version>${netty.version}</version>
    </dependency>
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-handler</artifactId>
        <version>${netty.version}</version>
    </dependency>
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-handler-proxy</artifactId>
        <version>${netty.version}</version>
    </dependency>
    <!-- BouncyCastle PKIX. Not referenced by the classes shown in this article (presumably kept for
         certificate handling elsewhere - confirm). Upgraded from 1.58, which predates security fixes
         shipped in later BouncyCastle releases; 1.70 is the last release of the jdk15on line. -->
    <dependency>
        <groupId>org.bouncycastle</groupId>
        <artifactId>bcpkix-jdk15on</artifactId>
        <version>1.70</version>
    </dependency>
    <!-- No versions given: presumably managed by spring-boot-starter-parent or the Spring Boot BOM (not shown). -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-test</artifactId>
        <scope>test</scope>
    </dependency>
</dependencies>
配置
# TCP port the Netty proxy listens on; injected into Runner through @Value("${websync.port}").
websync.port=9999
Spring 容器启动完成后（通过 ApplicationRunner）启动代理服务器
/**
 * Starts the Netty proxy server once the Spring application has started.
 *
 * <p>NOTE: {@code Server#start()} waits on {@code closeFuture().sync()}, so {@link #run} does not
 * return until the server channel is closed.
 */
@Component
public class Runner implements ApplicationRunner {

    /** Listening port, taken from the {@code websync.port} property. */
    @Value("${websync.port}")
    private int listenPort;

    @Override
    public void run(ApplicationArguments args) throws Exception {
        Server proxyServer = new Server(listenPort);
        proxyServer.start();
    }
}
服务端
/**
 * Netty front-end of the proxy.
 *
 * <p>Accepts browser connections on one port. Each connection runs through HTTP decoding and
 * aggregation, then a single shared {@link ServerHandler}.
 */
public class Server {
    /** Status line sent back after accepting an HTTPS CONNECT request. */
    public final static HttpResponseStatus SUCCESS = new HttpResponseStatus(200,
            "Connection established");

    private final int port;
    // Accepts connections for a single server channel, so one thread is enough.
    private final EventLoopGroup bossGroup = new NioEventLoopGroup(1);
    // Performs I/O for all accepted connections (Netty's default thread count).
    private final EventLoopGroup workerGroup = new NioEventLoopGroup();
    private final ServerBootstrap bootstrap = new ServerBootstrap();
    // One instance for every channel; permitted because ServerHandler is @Sharable.
    private final ServerHandler serverHandler = new ServerHandler();

    public Server(int port) {
        this.port = port;
    }

    /**
     * Binds the configured port and blocks until the server channel is closed.
     *
     * <p>The event loop groups are shut down when this method exits, normally or because the bind
     * failed.
     *
     * @throws InterruptedException if interrupted while binding or waiting for the channel to close
     */
    public void start() throws InterruptedException {
        try {
            bootstrap.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .localAddress(new InetSocketAddress(port))
                    .childHandler(new ChannelInitializer<NioSocketChannel>() {
                        @Override
                        protected void initChannel(NioSocketChannel socketChannel) throws Exception {
                            // ServerHandler removes these two handlers by name once a CONNECT
                            // tunnel is established, so the names must stay in sync.
                            socketChannel.pipeline().addLast("httpCodec", new HttpServerCodec());
                            socketChannel.pipeline().addLast("httpObject", new HttpObjectAggregator(65536));
                            socketChannel.pipeline().addLast(serverHandler);
                        }
                    });
            ChannelFuture bindFuture = bootstrap.bind().sync();
            // Block until the server channel is closed.
            bindFuture.channel().closeFuture().sync();
        } finally {
            // Without this the event loop threads outlive the server (and a failed bind).
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }
}
服务端handler
//线程间共享,但必须要保证此类线程安全
@ChannelHandler.Sharable
public class ServerHandler extends ChannelInboundHandlerAdapter {
private final static Log LOG = LogFactory.getLog(ServerHandler.class);
//保证线程安全
private ThreadLocal<ChannelFuture> futureThreadLocal = new ThreadLocal<>();
private final AtomicInteger PORT = new AtomicInteger(0);
private final AtomicReference<String> HOST = new AtomicReference<String>("0.0.0.0");
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
LOG.info("服务器连接成功......");
}
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
//http
if (msg instanceof FullHttpRequest) {
FullHttpRequest request = (FullHttpRequest) msg;
String name = request.method().name();
RequestProto protoUtil = ProtoUtil.getRequestProto(request);
String host = protoUtil.getHost();
int port = protoUtil.getPort();
PORT.set(port);
HOST.set(host);
request.headers().set("11", "222");
if ("CONNECT".equalsIgnoreCase(name)) {//HTTPS建立代理握手
HttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, Server.SUCCESS);
ctx.writeAndFlush(response);
ctx.pipeline().remove("httpCodec");
ctx.pipeline().remove("httpObject");
return;
}
//开启代理服务器
new ProxyServer(host, port, msg, ctx.channel()).start();
} else { //https,只转发数据,不对数据做处理,所以不需要解密密文
ChannelFuture future = futureThreadLocal.get();
//代理连接还未建立
if (future == null) {
//连接至目标服务器
Bootstrap bootstrap = new Bootstrap();
bootstrap.group(ctx.channel().eventLoop()) // 复用客户端连接线程池
.channel(ctx.channel().getClass()) // 使用NioSocketChannel来作为连接用的channel类
.handler(new ChannelInitializer() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx0, Object msg) throws Exception {
ctx.channel().writeAndFlush(msg);
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
System.out.println("https 代理服务器连接成功...");
}
});
}
});
future = bootstrap.connect(HOST.get(), PORT.get());
futureThreadLocal.set(future);
future.addListener(new ChannelFutureListener() {
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
future.channel().writeAndFlush(msg);
} else {
ctx.channel().close();
}
}
});
} else {
//代理建立连接之后,直接刷回数据
future.channel().writeAndFlush(msg);
}
}
}
}
代理客户端
public class ProxyServer {
private final String HOST;
private final int PORT;
private final Object msg;
private final Channel channel;
public ProxyServer(String HOST, int PORT, Object msg, Channel channel) {
this.HOST = HOST;
this.PORT = PORT;
this.msg = msg;
this.channel = channel;
}
public void start() {
Bootstrap bootstrap = new Bootstrap();
EventLoopGroup group = new NioEventLoopGroup();
bootstrap.group(group)
.channel(NioSocketChannel.class)
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel socketChannel) throws Exception {
socketChannel.pipeline().addLast(new HttpClientCodec());
socketChannel.pipeline().addLast(new HttpObjectAggregator(6553600));
socketChannel.pipeline().addLast(new ProxyServerHandler(channel));
}
})
.connect(new InetSocketAddress(HOST, PORT))
.addListener(new ChannelFutureListener() {
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
HeaderUtil.addHeaders(future, msg);
} else {
future.channel().close();
}
}
});
}
}
代理handler
public class ProxyServerHandler extends ChannelInboundHandlerAdapter {
private Channel channel;
public ProxyServerHandler(Channel channel) {
this.channel = channel;
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
System.out.println("代理服务器连接成功.....");
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
channel.writeAndFlush(msg);
}
}
工具类
/** Writes the forwarded request to an upstream connection. */
public class HeaderUtil {
    /**
     * Writes request to the channel of future, first setting the header {@code 111: 222} when
     * request is an HTTP request.
     *
     * @param future  connect future of the upstream channel (expected to have succeeded)
     * @param request message to forward; any {@link HttpRequest}, aggregated or not, gets the header
     * @author CemB
     * @date 2018/12/20 17:18
     */
    public static void addHeaders(ChannelFuture future, Object request) {
        if (request instanceof HttpRequest) {
            // Cast to the type that was actually checked: casting to FullHttpRequest threw
            // ClassCastException for requests that were not aggregated.
            HttpRequest msg = (HttpRequest) request;
            msg.headers().set("111", "222");
        }
        future.channel().writeAndFlush(request);
    }
}
/** Extracts the proxy target (host, port, TLS flag) from an incoming HTTP request. */
public class ProtoUtil {
    /** Authority of a request URI: optional http(s):// prefix, then everything up to the first '/'. */
    private static final Pattern URI_AUTHORITY =
            Pattern.compile("^(?:https?://)?(?<host>[^/]*)/?.*$");
    /** host[:port][/path] with an optional http(s):// prefix. */
    private static final Pattern HOST_PORT =
            Pattern.compile("^(?:https?://)?(?<host>[^:]*)(?::(?<port>\\d+))?(/.*)?$");

    /**
     * Resolves the target of httpRequest.
     *
     * <p>The host comes from the Host header, or from the URI when the header is missing. The port
     * is looked up in this order:
     * <ol>
     *   <li>the Host value;</li>
     *   <li>an absolute-form URI;</li>
     *   <li>otherwise 443 when the URI or Host value starts with "https", else 80.</li>
     * </ol>
     *
     * @param httpRequest request received by the proxy
     * @return the target, or null when no host can be extracted
     */
    public static RequestProto getRequestProto(HttpRequest httpRequest) {
        String uriStr = httpRequest.uri();
        String hostStr = httpRequest.headers().get(HttpHeaderNames.HOST);
        if (hostStr == null) {
            Matcher authority = URI_AUTHORITY.matcher(uriStr);
            if (!authority.find()) {
                return null;
            }
            hostStr = authority.group("host");
        }
        Matcher hostMatcher = HOST_PORT.matcher(hostStr);
        if (!hostMatcher.find()) {
            // e.g. a bracketed IPv6 literal or a non-numeric port: no usable host.
            return null;
        }
        // Take the port from the Host value first, then from the URI (issues#4).
        String portTemp = hostMatcher.group("port");
        // Only an absolute-form URI carries an authority. Parsing an origin-form URI ("/...")
        // would mistake a colon in its path or query (e.g. "/a?t=x:8080") for a port.
        if (portTemp == null && !uriStr.startsWith("/")) {
            Matcher uriMatcher = HOST_PORT.matcher(uriStr);
            if (uriMatcher.find()) {
                portTemp = uriMatcher.group("port");
            }
        }
        boolean isSsl = uriStr.startsWith("https") || hostStr.startsWith("https");
        int port;
        if (portTemp != null) {
            port = Integer.parseInt(portTemp);
        } else {
            port = isSsl ? 443 : 80;
        }
        RequestProto requestProto = new RequestProto();
        requestProto.setHost(hostMatcher.group("host"));
        requestProto.setPort(port);
        requestProto.setSsl(isSsl);
        return requestProto;
    }

    /** Target of a proxied request: host name, port and whether "https" was detected. */
    public static class RequestProto implements Serializable {
        private static final long serialVersionUID = -6471051659605127698L;
        private String host;
        private int port;
        private boolean ssl;

        public RequestProto() {
        }

        public RequestProto(String host, int port, boolean ssl) {
            this.host = host;
            this.port = port;
            this.ssl = ssl;
        }

        public String getHost() {
            return host;
        }

        public void setHost(String host) {
            this.host = host;
        }

        public int getPort() {
            return port;
        }

        public void setPort(int port) {
            this.port = port;
        }

        public boolean getSsl() {
            return ssl;
        }

        public void setSsl(boolean ssl) {
            this.ssl = ssl;
        }
    }
}