代码小记---hibernate参数校验,自定义枚举校验注解及消息中添加自定义参数

注解定义:

package com.example.demo.validate;

import com.alibaba.fastjson.JSON;
import org.apache.commons.beanutils.PropertyUtils;
import org.apache.commons.lang3.EnumUtils;
import org.apache.commons.lang3.StringUtils;
import org.hibernate.validator.constraintvalidation.HibernateConstraintValidatorContext;
import org.springframework.util.CollectionUtils;

import javax.validation.Constraint;
import javax.validation.ConstraintValidator;
import javax.validation.ConstraintValidatorContext;
import javax.validation.Payload;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

import static java.lang.annotation.RetentionPolicy.RUNTIME;

/**
 * Constrains a {@code String} to the values exposed by the constants of an enum.
 *
 * <p>For every constant of {@link #clazz()} that is not annotated with {@link Illegal}, the value
 * of {@link #property()} (or the constant's {@code name()} when the property is
 * {@code "Enum.name"}, case-insensitive) is converted with {@code toString()} and added to the
 * allow-list. On violation, the allow-list is exposed to the message template as the
 * {@code {items}} parameter, rendered as a JSON array.
 */
@Target(value = {ElementType.FIELD, ElementType.PARAMETER})
@Retention(RUNTIME)
@Constraint(validatedBy = {EnumCheck.StringConstraintValidator.class})
public @interface EnumCheck {
    Class<?>[] groups() default {};

    Class<? extends Payload>[] payload() default {};

    /** Message template; the default key is looked up in {@code ValidationMessages*.properties}. */
    String message() default "{org.sang.enum.check}";

    /** Enum type whose constants define the allowed values; a non-enum class allows nothing. */
    Class<?> clazz();

    /**
     * Bean property read from each constant through its getter (e.g. {@code "id"} reads
     * {@code getId()}), or {@code "Enum.name"} to use the constant's name.
     */
    String property();

    /** Validator backing {@link EnumCheck} for {@code String} targets. */
    class StringConstraintValidator implements ConstraintValidator<EnumCheck, String> {
        // Insertion-ordered so {items} lists values in enum declaration order.
        private final Set<String> allowValues = new LinkedHashSet<>();

        // JSON rendering of allowValues, computed once in initialize() instead of per isValid call.
        private String itemsJson = "[]";

        @Override
        public void initialize(EnumCheck constraintAnnotation) {
            Class<?> clazz = constraintAnnotation.clazz();
            String property = constraintAnnotation.property();
            if (clazz == null || !clazz.isEnum() || StringUtils.isBlank(property)) {
                // Unusable declaration: keep the allow-list empty, so every value is rejected.
                return;
            }
            for (Object constant : clazz.getEnumConstants()) {
                Enum<?> item = (Enum<?>) constant;
                if (isIllegal(item)) {
                    continue;
                }
                String value = readValue(item, property);
                // Skip null property values: adding null would let a null input pass and
                // would print "null" inside {items}.
                if (value != null) {
                    allowValues.add(value);
                }
            }
            itemsJson = JSON.toJSONString(allowValues);
        }

        /**
         * Accepts {@code value} only if it is in the allow-list. A {@code null} value is rejected,
         * because the allow-list never contains null.
         * NOTE(review): Bean Validation recommends treating null as valid and pairing with
         * {@code @NotNull}; this validator intentionally keeps the original "null is invalid" rule.
         */
        @Override
        public boolean isValid(String value, ConstraintValidatorContext context) {
            if (allowValues.contains(value)) {
                return true;
            }
            // {items} is only interpolated when a violation is reported, so add it on failure only.
            context.unwrap(HibernateConstraintValidatorContext.class)
                    .addMessageParameter("items", itemsJson);
            return false;
        }

        /** Returns whether the field backing this enum constant carries {@link Illegal}. */
        private static boolean isIllegal(Enum<?> item) {
            try {
                return item.getDeclaringClass().getField(item.name()).isAnnotationPresent(Illegal.class);
            } catch (NoSuchFieldException e) {
                // Every enum constant is a public static field with the constant's name.
                throw new IllegalStateException("No field for enum constant " + item.name(), e);
            }
        }

        /** Returns the string form of {@code property} on {@code item}, or null if that value is null. */
        private static String readValue(Enum<?> item, String property) {
            if ("Enum.name".equalsIgnoreCase(property)) {
                return item.name();
            }
            try {
                Object field = PropertyUtils.getProperty(item, property);
                return field == null ? null : field.toString();
            } catch (ReflectiveOperationException e) {
                // A missing or inaccessible getter is a declaration error; fail at initialization
                // instead of silently rejecting every value.
                throw new IllegalArgumentException("Cannot read property '" + property + "' of "
                        + item.getDeclaringClass().getName() + "." + item.name(), e);
            }
        }
    }

    /**
     * Marks an enum constant that must be excluded from the allowed values.
     *
     * <p>RUNTIME retention is required because the validator looks this annotation up
     * reflectively. With the default CLASS retention it would be invisible and exclude nothing.
     */
    @Retention(RUNTIME)
    @interface Illegal {
    }
}

声明一个枚举:

package com.example.demo.validate;

import lombok.Getter;

/**
 * Sample enum whose {@code id} values are referenced by {@code @EnumCheck(property = "id")}.
 * Constant {@code C} is marked {@link EnumCheck.Illegal}, meaning it is intended to be excluded
 * from the allowed values.
 */
@Getter
public enum MyEnum {
    A("aa"),
    B("bb"),
    @EnumCheck.Illegal
    C("cc");

    // Immutable per-constant identifier; Lombok's @Getter exposes it as getId().
    private final String id;

    MyEnum(String id) {
        this.id = id;
    }
}

测试使用注解:

package com.example.demo.validate;

import lombok.Data;
import org.hibernate.validator.HibernateValidator;

import javax.validation.Validation;
import javax.validation.Validator;

@Data
public class User {
    @EnumCheck(clazz = MyEnum.class, property = "id")
    private String id;

    public static void main(String[] args) {
        Validator validatorAll = Validation.byProvider(HibernateValidator.class).configure().failFast(false).buildValidatorFactory().getValidator();
        User user = new User();
        user.setId("123");
        System.out.println(validatorAll.validate(user));
    }
}

在resources下直接添加国际化文件:
ValidationMessages.properties

org.sang.enum.check=must be one of {items}

ValidationMessages_zh_CN.properties

org.sang.enum.check=必须是{items}之一

===================================================================
梦想还是要有的,万一实现了呢!

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
好的,以下是使用PyTorch实现 "Learning a Deep ConvNet for Multi-label Classification with Partial Labels" 论文的示例代码。

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from sklearn.metrics import f1_score
from dataset import CustomDataset
from model import ConvNet

# 设置随机数种子,保证结果可重复
torch.manual_seed(2022)

# 定义超参数
epochs = 50
batch_size = 128
learning_rate = 0.001
weight_decay = 0.0001
num_classes = 20
num_labels = 3

# 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载数据集
train_dataset = CustomDataset(root='./data', split='train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_dataset = CustomDataset(root='./data', split='test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# 定义模型
model = ConvNet(num_classes=num_classes, num_labels=num_labels)

# 定义损失函数和优化器
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# 训练模型
for epoch in range(epochs):
    # 训练阶段
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    train_loss = running_loss / len(train_loader)

    # 测试阶段
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            outputs = model(inputs)
            predicted_labels = torch.round(torch.sigmoid(outputs))
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted_labels.cpu().numpy())
    f1 = f1_score(y_true, y_pred, average='macro')
    print('[Epoch %d] Train Loss: %.3f, Test F1: %.3f' % (epoch + 1, train_loss, f1))
```

`CustomDataset` 和 `ConvNet` 分别是数据集类和模型类,需要根据您的具体情况进行实现。在训练阶段,使用 `nn.BCEWithLogitsLoss()` 作为损失函数进行优化。在测试阶段,使用 `sklearn.metrics.f1_score()` 计算 F1 值作为模型评估指标。

希望以上示例代码对您有所帮助!
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值