开发小记 - 用函数式编程优化代码可读性,减少一半行数

前言

本文主要是记录一下用lambda 表达式优化代码的经历,篇幅不长,算是分享我觉得不错的一个小技巧。

话不多说,直接进入正题。

正文

我们先来看这么一段代码:

@Component
public class ConfigCacheHelper {

    private final RedisHelper redisHelper;

    private final IChannelConfigMapper iChannelConfigMapper;

    @Autowired
    public ConfigCacheHelper(RedisHelper redisHelper, IChannelConfigMapper iChannelConfigMapper) {
        this.redisHelper = redisHelper;
        this.iChannelConfigMapper = iChannelConfigMapper;
    }

    public AaaChannelConfig getAaaChannelConfig(String merchantId){
        if (StringUtils.isEmpty(merchantId)){
            throw new IllegalArgumentException("商户号不能为空");
        }
        Object obj = redisHelper.hget(RedisKey.CHANEL_CONFIG, RedisKey.AAA_CHANNEL);
        AaaChannelConfig config;
        if (obj == null){
            config = iChannelConfigMapper.selectAaaChannelConfig(merchantId);
        }
        else {
            Map<String, AaaChannelConfig> map = (Map<String, AaaChannelConfig>)obj;
            config = map.get(merchantId);
        }

        return Objects.requireNonNull(config, "获取Aaa渠道配置为空");
    }

    public BbbChannelConfig getBbbChannelConfig(String merchantId){
        if (StringUtils.isEmpty(merchantId)){
            throw new IllegalArgumentException("商户号不能为空");
        }
        Object obj = redisHelper.hget(RedisKey.CHANEL_CONFIG, RedisKey.BBB_CHANNEL);
        BbbChannelConfig config;
        if (obj == null){
            config = iChannelConfigMapper.selectBbbChannelConfig(merchantId);
        }
        else {
            Map<String, BbbChannelConfig> map = (Map<String, BbbChannelConfig>)obj;
            config = map.get(merchantId);
        }

        return Objects.requireNonNull(config, "获取Bbb渠道配置为空");
    }

    public CccChannelConfig getCccChannelConfig(String merchantId, String posId, String operatorId){
        if (StringUtils.isEmpty(merchantId)){
            throw new IllegalArgumentException("商户号不能为空");
        }
        Object obj = redisHelper.hget(RedisKey.CHANEL_CONFIG, RedisKey.CCC_CHANNEL);
        CccChannelConfig config;
        if (obj == null){
            config = iChannelConfigMapper.selectCccChannelConfig(merchantId, posId, operatorId);
        }
        else {
            Map<String, CccChannelConfig> map = (Map<String, CccChannelConfig>)obj;
            config = map.get(String.format("%s_%s_%s", merchantId, posId, operatorId));
        }

        return Objects.requireNonNull(config, "获取Ccc渠道配置为空");
    }

    // ... 此处再省略N个渠道的config

}

俺是做支付的,这段代码的逻辑很简单,就是获取某个支付渠道的商户配置,缓存取不到就去数据库取。

在IDEA乍一看,没看出什么问题,代码检查插件也没有报什么warning,但是当我在这个类里面新增获取第N个渠道的方法的时候,我就感觉到了这块代码不是很优雅。

我总结出来两点:

  1. 多余的StringUtils.isEmpty(merchantId)

            if (StringUtils.isEmpty(merchantId)){
                throw new IllegalArgumentException("商户号不能为空");
            }
    

    理由有二:

    • 判断字符串为空应该是调用者的责任
    • 外部的业务逻辑早就确保merchantId 不可能为空字符串,没必要再判断
  2. getXXXChannelConfig逻辑可以提取成如下

        public AaaChannelConfig getAaaChannelConfig(String merchantId){
            // if (StringUtils.isEmpty(merchantId)){
            //     throw new IllegalArgumentException("商户号不能为空");
            // }
            // 1️⃣
            Object obj = redisHelper.hget(RedisKey.CHANEL_CONFIG, {渠道键值});
            AaaChannelConfig config;
            if (obj == null){
                // 2️⃣
                // selectBbbChannelConfig()
                // selectCccChannelConfig(merchantId, posId, operatorId)
                config = iChannelConfigMapper.{取某个渠道配置的方法}(...);
            }
            else {
                Map<String, AaaChannelConfig> map = (Map<String, AaaChannelConfig>)obj;
                // 3️⃣
                // config = map.get(String.format("%s_%s_%s", merchantId, posId, operatorId));
                // config = map.get(merchantId);
                config = map.get({渠道配置Map的键值});
            }
    
            return Objects.requireNonNull(config, "获取Aaa渠道配置为空");
        }
    

第一点很简单,不讲了。主要来讲下第二点,从上面的分析中,就可以抽取出3个变量:

  • 渠道键值
  • 去数据库中取配置的函数 - 配置的提供者
  • 渠道配置Map的键值

这样子我们就可以把代码改造成下面这样:

    
    private <T> T getChannelConfig(String configKey, String configMapInnerKey,
                                   Supplier<T> daoSupplier) {
        Object obj = redisHelper.hget(RedisKey.CHANNELPROXY_HKEY, configKey);
        T config;
        if (obj == null){
            config = daoSupplier.get();
        } else {
            Map<String, T> map = (Map<String, T>)obj;
            config = map.get(configMapInnerKey);
        }
        return Objects.requireNonNull(config, "获取渠道配置为空, 渠道值:" + configKey);
    }

    public AaaChannelConfig getAaaChannelConfig(String merchantId){
        return getChannelConfig(
            RedisKey.AAA_CHANNEL, merchantId,
            () -> iChannelConfigMapper.selectAaaChannelConfig(merchantId)
        );
    }

    public BbbChannelConfig getBbbChannelConfig(String merchantId){
        return getChannelConfig(
            RedisKey.BBB_CHANNEL, merchantId,
            () -> iChannelConfigMapper.selectBbbChannelConfig(merchantId)
        );
    }

    public CccChannelConfig getCccChannelConfig(String merchantId, String posId, String operatorId){
        return getChannelConfig(
            RedisKey.CCC_CHANNEL, String.format("%s_%s_%s", merchantId, posId, operatorId),
            () -> iChannelConfigMapper.selectCccChannelConfig(merchantId, posId, operatorId)
        );
    }

这里简单提一下Supplier,这是java.util.function中提供的函数式接口,用来支持Java 中的函数式编程。从语义上理解就是“T的提供者”,比如在上文语境中就是对应渠道配置的提供者。类似的常用接口还有:

接口参数返回类型
Predicate<T>Tboolean
Consumer<T>Tvoid
Function<T,R>TR
Supplier<T>NoneT
UnaryOperator<T>TT

优化结果

这样优化之后,提升了代码的可读性,在实现相同功能的前提下,比原来减少了一半的代码量。(233 -> 137)

如何抛出我想要的异常?

这里想要再提一个场景,因为之前有碰到过,就是如何让你的Function抛出异常?

还是拿上述代码为例,

    private <T> T getChannelConfig(String configKey, String configMapInnerKey,
                                   Supplier<T> daoSupplier) {
        Object obj = redisHelper.hget(RedisKey.CHANNELPROXY_HKEY, configKey);
        T config;
        if (obj == null){
            // 假设我想让这个方法抛出一个自定义的BizException 业务异常,怎么办?
            config = daoSupplier.get();
        } else {
            Map<String, T> map = (Map<String, T>)obj;
            config = map.get(configMapInnerKey);
        }
        return Objects.requireNonNull(config, "获取渠道配置为空, 渠道值:" + configKey);
    }

我刚开始想了蛮久的,后面发现这实际上是属于Java 基础方面的知识。

我们传入的参数daoSuppier实际上相当于是一个函数式接口Supplier的匿名内部实现类而已(当然底层实现是不一样的,在某种意义上比匿名内部类好很多,无论是性能,可读性还是使用趋势)

@FunctionalInterface
public interface Supplier<T> {

    /**
     * Gets a result.
     *
     * @return a result
     */
    T get();
}

接口可以声明抛出某个异常,它的实现可以不抛出异常,反之呢,如果它的实现抛出了受检异常,这个接口就必须显式声明抛出这个异常。

所以这种情况下,我们如果想我们的Supplier抛出我们想要的异常,那么就需要自己声明一个Functional Interface

    public interface DaoSupplier<T> {

        /**
         * Gets a result.
         *
         * @return a result
         */
        T get() throws BizException;
    }

再次改造后的代码:


    private <T> T getChannelConfig(String configKey, String configMapInnerKey,
                                   DaoSupplier<T> daoSupplier) throws BizException {
        Object obj = redisHelper.hget(RedisKey.CHANNELPROXY_HKEY, configKey);
        T config;
        if (obj == null){
            config = daoSupplier.get();
        } else {
            Map<String, T> map = (Map<String, T>)obj;
            config = map.get(configMapInnerKey);
        }
        return Objects.requireNonNull(config, "获取渠道配置为空, 渠道值:" + configKey);
    }
    
    public AaaChannelConfig getAaaPayChannelConfig(String merchantId) throws BizException {
        return getChannelConfig(
                RedisKey.AAA_CHANNEL, merchantId,
                () -> Optional.ofNullable(iChannelConfigMapper.selectAliPayChannelConfig(appId))
                        .orElseThrow(BizException::new)
        );
    }

结语

本文并没有引入很多的Java Lambda的原理性介绍、API介绍,因为本身就是个人的开发小记,偏重于实践,引入太多的知识性介绍反而偏离了本意。希望Java 函数式不熟悉的同学可以自己学习下相关资料。

如果本文有帮助到你,希望能点个赞,这是对我的最大动力????????????????。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
好的,以下是使用PyTorch实现 "Learning a Deep ConvNet for Multi-label Classification with Partial Labels" 论文的示例代码。 ```python import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision.transforms import transforms from sklearn.metrics import f1_score from dataset import CustomDataset from model import ConvNet # 设置随机数种子,保证结果可重复 torch.manual_seed(2022) # 定义超参数 epochs = 50 batch_size = 128 learning_rate = 0.001 weight_decay = 0.0001 num_classes = 20 num_labels = 3 # 定义数据预处理 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # 加载数据集 train_dataset = CustomDataset(root='./data', split='train', transform=transform) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2) test_dataset = CustomDataset(root='./data', split='test', transform=transform) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2) # 定义模型 model = ConvNet(num_classes=num_classes, num_labels=num_labels) # 定义损失函数和优化器 criterion = nn.BCEWithLogitsLoss() optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay) # 训练模型 for epoch in range(epochs): # 训练阶段 model.train() running_loss = 0.0 for i, data in enumerate(train_loader): inputs, labels = data optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() train_loss = running_loss / len(train_loader) # 测试阶段 model.eval() y_true, y_pred = [], [] with torch.no_grad(): for data in test_loader: inputs, labels = data outputs = model(inputs) predicted_labels = torch.round(torch.sigmoid(outputs)) y_true.extend(labels.cpu().numpy()) y_pred.extend(predicted_labels.cpu().numpy()) f1 = f1_score(y_true, y_pred, average='macro') print('[Epoch %d] Train Loss: %.3f, Test F1: %.3f' % (epoch + 1, train_loss, f1)) ``` `CustomDataset` 和 `ConvNet` 分别是数据集类和模型类,需要根据您的具体情况进行实现。在训练阶段,使用 `nn.BCEWithLogitsLoss()` 作为损失函数进行优化。在测试阶段,使用 `sklearn.metrics.f1_score()` 计算 F1 值作为模型评估指标。 希望以上示例代码对您有所帮助!

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值