K-Means聚类算法JAVA实现


前言

善始者繁多,克终者盖寡。

K-Means是用途极为广泛的聚类算法,因为其操作简单、易于实现的特点,它也是应用最多的算法之一,今天从K-Means算法的原理入手,使用JAVA实现K-Means聚类算法。

一、聚类与分类的区别

聚类算法属于无监督算法目标类别未知,常见的聚类算法有K-Means、DBSCAN等;
分类算法属于有监督算法目标类别已知,常见的分类算法有决策树、朴素贝叶斯、支持向量机等。

二、K-Means聚类过程

K-Mean实现步骤如下:

①设定类别数K和聚类迭代次数,在所有样本构成的样本空间中选择K个初始点作为初始聚类中心,初始聚类中心可以是某个样本,也可以为样本空间中的任意点
②计算所有样本到K个聚类中心的距离,根据距离最小原则将所有样本划分到不同的类别中,样本到聚类中心的距离通常使用欧式距离表示。给定聚类中心C=(c_1, c_2,…, c_n)和样本X=(x_1, x_2,…, x_n),n为样本的属性个数,样本到聚类中心的距离Dis(cx)表示为:

在这里插入图片描述

③将初始聚类中所有样本到聚类中心距离的均值作为新的聚类中心,给定某个聚类中样本的集合X=(x_1, x_2,…, x_m),m表示样本个数,其中x_i=(x_i1, x_i2,…, x_in),i表示样本集合中的第i个样本,n表示样本的属性个数,新的聚类中心C可用向量表示为:

在这里插入图片描述

④重复步骤②和③,直至各聚类中心不再改变或者达到最大迭代次数。

上述操作看起来很复杂,其实很简单,就是依次计算各样本到聚类中心的距离,把距离小的样本都放在一个类别中,在根据这个类别中的样本计算出新的聚类中心,使用的方法就是求“均值”。

三、JAVA实现

3.1 变量说明

    private int k; //聚类数目
    private int m; //最大迭代次数
    private int dataLength; //数据集中数据的个数
    private ArrayList<double[]> data; //数据集
    private ArrayList<double[]> center; //聚类中心,结构与各数据点相同
    private ArrayList<ArrayList<double[]>> cluster; //聚类形成的簇
    private ArrayList<Double> SEE; //聚类中系统整体误差平方和
    private int temp; //用于记录最终迭代次数
    private ArrayList<double[]> center_copy; //记录初始聚类中心
    private int DIMENSION; //记录此次数据点的维度

3.2 构造器与GET/SET方法

根据需要,仅为前6个变量设置GET/SET方法。

//空参构造函数
    public MyKmeans(){

    }
    //包含K值的构造函数
    public MyKmeans(int k){
        this.k = k;
    }
    public MyKmeans(int k,int m){
        this.k = k;
        this.m = m;
    }

    public int getM() {
        return m;
    }

    public void setM(int m) {
        this.m = m;
    }

    public int getDataLength() {
        return dataLength;
    }

    public void setDataLength(int dataLength) {
        this.dataLength = dataLength;
    }

    public ArrayList<double[]> getData() {
        return data;
    }

    public void setData(ArrayList<double[]> data) {
        this.data = data;
    }

    public ArrayList<double[]> getCenter() {
        return center;
    }

    public void setCenter(ArrayList<double[]> center) {
        this.center = center;
    }

    public ArrayList<ArrayList<double[]>> getCluster() {
        return cluster;
    }

    public void setCluster(ArrayList<ArrayList<double[]>> cluster) {
        this.cluster = cluster;
    }

3.3 初始化

在系统运行前需对数据集、初始聚类中心、簇等进行初始化,同时还需要对聚类数k和迭代次数m进行检测,聚类数最大不超过数据总个数,最小不低于1。

//初始化聚类,保证程序能够正常运行
    public void init(){
//        //默认情况下迭代次数为10,参考SPSS,暂未实装
//        m=10;
        //读取数据文件
        readData();
        if (data==null || data.size()==0){
            //使用系统自带的初始数据集
            initData();
        }
        dataLength = data.size();
        //判断K的取值,如果聚类数小于0则设为1类,如果大于数据集中元素个数,则设为dataLength个类
        if (k<=0){
            k=1;
        }
        if(k>dataLength){
            k=dataLength;
        }
        //初始化聚类中心,使用随机数实现
        initCenter();
        //初始化聚类结果,此时聚类结果为K个空的簇
        initCluster();
        //初始化聚类中的误差平方和
        initSEE();
    }

3.3.1 初始化数据集

如果为提供数据集,为保证程序正常运行,使用程序默认提供的数据集。

//读取数据文件
    public void readData(){
        data = new ArrayList<double[]>();
        FileInputStream fileInputStream = null;
        InputStreamReader inputStreamReader = null;
        BufferedReader bufferedReader = null;
        try {
            fileInputStream = new FileInputStream(new File("./src/data.txt"));
             inputStreamReader = new InputStreamReader(fileInputStream);
             bufferedReader = new BufferedReader(inputStreamReader);
            String str = null;
            while ((str = bufferedReader.readLine()) != null){
                //获取每一行数据,创建一个一维数组暂时存储这些数据
                int len = str.split(",").length;
                double[] temp_data = new double[len];
                for (int i = 0; i < temp_data.length; i++) {
                    temp_data[i] = Double.parseDouble(str.split(",")[i]);
                }
                data.add(temp_data);
                DIMENSION = len;
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            System.out.println("未找到指定文件!");
        } catch (IOException e) {
            e.printStackTrace();
            System.out.println("打开文件出错!");
        }finally {
            try {
                bufferedReader.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
            try {
                inputStreamReader.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
            try {
                fileInputStream.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }
    //当没有读取本地文件时,使用系统自动的初始数据集,初始数据集为10个二维平面中的点
    public void initData(){
        DIMENSION = 2;
        data = new ArrayList<double[]>();
        double[][] default_data = new double[][]{
                {0,0},
                {1,1},
                {2,2},
                {3,3},
                {4,4},
                {5,5},
                {6,6},
                {7,7},
                {8,8},
                {9,9}
        };
        for (int i = 0; i < default_data.length; i++) {
            data.add(default_data[i]);
        }
    }

3.3.2 初始化聚类中心

本人选择从原始数据集中选择初始聚类中心,此方法可能导致初始聚类中心重复!!!

//显示聚类中心
    public void show_Cneter(ArrayList<double[]> center){
        for (int i  = 0; i < center.size(); i++) {
            System.out.print("[");
            for (int j = 0; j < DIMENSION; j++) {
                System.out.print(center.get(i)[j]);
                if (j!= DIMENSION-1)
                    System.out.print(",");
            }
            System.out.print("]\t");
        }
    }

    //初始化聚类中心,使用随机数生成
    public void initCenter(){
        center = new ArrayList<double[]>();
        for (int i = 0; i <k; i++) {
            Random random = new Random();
            double[] randoms_center;
            //随机指定数据点作为初始中心,可换用其他方法生成随机初始中心
            int index = random.nextInt(dataLength);
            randoms_center = data.get(index);
            center.add(randoms_center);
        }
        //保存初始聚类中心的副本
        //center_copy = center; 此种方法万万不可写
        center_copy = new ArrayList<>();
        center_copy.addAll(center);
        System.out.print("初始聚类中心是:");
        show_Cneter(center_copy);
        System.out.println();
        System.out.println();
    }

3.3.3 初始化簇

 //初始化聚类结果,此时聚类中包含K个孔的簇
    public void initCluster(){
        cluster = new ArrayList<>();
        for (int i = 0; i < k; i++) {
            cluster.add(new ArrayList<>());
        }
    }

    //显示每个簇中的元素
    public void show_ClusterData(ArrayList<ArrayList<double[]>> cluster){
        for (int i = 0; i < cluster.size(); i++) {
            System.out.print("类别"+(i+1)+"包含元素:");
            for (int j = 0; j < cluster.get(i).size(); j++) {
                System.out.print("[");
                for (int index = 0;index < DIMENSION; index++){
                    System.out.print(cluster.get(i).get(j)[index]);
                    if (index!= DIMENSION-1)
                        System.out.print(",");
                }
                System.out.print("]\t");
            }
            System.out.println();
        }
    }

3.3.4 初始化SEE

SEE就是系统的误差,程序停止运行的条件是:

①达到最大迭代次数;
②程序误差不再改变,即SEE的值不再改变。

    //初始化聚类中的误差平方和
    public void initSEE(){
        SEE = new ArrayList<>();
    }

3.4 计算两点间距离

/**
     * 计算两个点之间的距离
     * @param p1 第一个点
     * @param p2 第二个点
     * @return 两个点见的欧式距离
     */
    private double distance(double[] p1,double[] p2){
        double result;
        double temp_sum = 0.0;
        for (int i = 0; i < p1.length; i++) {
            temp_sum += (p1[i]-p2[i])*(p1[i]-p2[i]);
        }
        result = Math.sqrt(temp_sum);
        return result;
    }

3.5 找到距离最小的聚类中心

//找到当前数据距离聚类中心最小的类别位置
    private int minDistance(double[] disstance){
        double min_distance = disstance[0];
        int min_index = 0;
        for (int i = 1; i < disstance.length; i++) {
            if (disstance[i]<=min_distance){
                min_distance = disstance[i];
                min_index = i;
            }
        }
        return min_index;
    }

3.6 将数据添加到对应的簇中

根据数据与聚类中心的距离,找到距离最小的聚类中心,将数据加入到该簇中。

//将当前数据元素放到聚类最近的簇中
    private void clusterSet(){
        double[] dis = new double[k];
        System.out.print("此时聚类中心:");
        show_Cneter(center);
        for (int i = 0; i < data.size(); i++) {
            System.out.println();
            for (int j = 0; j < k; j++) {
                dis[j] = distance(data.get(i),center.get(j));
            }
            System.out.print("第"+i+"个元素到中心的距离是:");
            for (int j = 0; j < dis.length; j++) {
                System.out.print(dis[j] + "\t");
            }
            int location = minDistance(dis);
            cluster.get(location).add(data.get(i));
        }
        System.out.println();
        //显示此时簇中包含的元素
        show_ClusterData(cluster);
        System.out.println();
    }

3.7 计算系统SEE

将所有数据放入对应簇后,计算当前系统的SEE值,若此时SEE值与前一次聚类所得SEE值相同,则应结束聚类。

/**
     * 求两点之间的误差平方
     * @param p1 第一个点
     * @param p2 第二个点
     * @return 两点之间的误差平方(距离)
     */
    private double errorSquare(double[] p1,double[] p2){
        double temp_sum = 0.0;
        for (int i = 0; i < p1.length; i++) {
            temp_sum += (p1[i]-p2[i])*(p1[i]-p2[i]);
        }
        return temp_sum;
    }

    //计算当前分类中所有簇中误差平方和
    private void countSEE(){
        double temp = 0;
        for (int i = 0; i < cluster.size(); i++) {
            for (int j = 0; j < cluster.get(i).size(); j++) {
                //计算当前簇中的所有数据到该簇聚类中心的距离
                temp += errorSquare(cluster.get(i).get(j),center.get(i));
            }
        }
        SEE.add(temp);
    }

3.8 设置新的聚类中心(最重要操作)

当未达到最大迭代次数或者未收敛时需更新系统聚类中心,以进行下一次聚类。

//设置新的聚类中心,依照以聚类好的各簇中数据求出新的聚类中心
    private void setNewCenter(){
//        System.out.println("新的聚类中心是:");
        for (int i = 0; i < cluster.size(); i++) {
            double[] temp_center = new double[DIMENSION];
            int n = cluster.get(i).size();
            if (n != 0){
                for (int j = 0; j < n; j++) {
                    for (int index = 0; index < DIMENSION; index++) {
                        temp_center[index] += cluster.get(i).get(j)[index];
                    }
                }
                for (int j = 0; j < DIMENSION; j++) {
                    temp_center[j] = temp_center[j]/n;
                }
                //将新的聚类中心放入动态数组
                center.set(i,temp_center);
            }
//            System.out.print("["+center.get(i)[0]+","+center.get(i)[1]+"]\t");
        }
        System.out.println();
    }

3.9 迭代

让程序重复执行,直至SEE收敛或达到最大迭代次数

/**
     *  kmeans算法具体实施步骤
     */
    public void kmeans(){
        //第一步,初始化各参数
        init();
        //第二步,执行聚类操作,直到收敛或者到达迭代次数
        temp = 1; //用来记录迭代次数
        while (true){
            //将各数据放入对应簇中
            clusterSet();
            //计算对应的误差平方好
            countSEE();
            if (temp > m){
                break;
            }
            if (SEE.size()!=1){
                if (SEE.get(temp-1) - SEE.get(temp-2) == 0)
                    break;
            }
            //第三步,设置新的聚类中心,重新开始聚类
            setNewCenter();
            cluster.clear();
            initCluster();
            //让迭代次数增加
            temp++;
        }
    }

3.10 显示聚类结果

/**
     * 显示聚类最终信息
     */
    public void show(){
        System.out.print("初始聚类中心是:");
        show_Cneter(center_copy);
        System.out.println();

        System.out.print("最终聚类中心:");
        show_Cneter(center);
        System.out.println();

        System.out.println("迭代执行的次数为:"+(temp));
        System.out.print("各阶段系统误差平方和");
        for (int i = 0; i < SEE.size(); i++) {
            System.out.print(SEE.get(i)+"\t");
        }
        System.out.println();
        //显示最后系统中各簇中的元素
        show_ClusterData(cluster);
    }

四、程序测试

4.1 测试程序

现将聚类数设为4,最大迭代次数为10。

public class MyTest {
	public static void main(String[] args) {
		MyKmeans myKmeans = new MyKmeans(4, 10);
		myKmeans.kmeans();
		myKmeans.show();
	}
}

4.2 默认数据集

程序第一次操作运行结果为:
在这里插入图片描述程序最终运行结果为:
在这里插入图片描述

4.3 其他数据集

程序第一次操作结果为:
在这里插入图片描述程序最终运行结果为:
在这里插入图片描述

五、完整代码

整个程序三百多行代码,编写途中可能存在纰漏,还请大家指教!!!

import java.io.*;
import java.util.ArrayList;
import java.util.Random;

public class MyKmeans {
    private int k; //聚类数目
    private int m; //最大迭代次数
    private int dataLength; //数据集中数据的个数
    private ArrayList<double[]> data; //数据集
    private ArrayList<double[]> center; //聚类中心,结构与各数据点相同
    private ArrayList<ArrayList<double[]>> cluster; //聚类形成的簇
    private ArrayList<Double> SEE; //聚类中系统整体误差平方和
    private int temp; //用于记录最终迭代次数
    private ArrayList<double[]> center_copy; //记录初始聚类中心
    private int DIMENSION; //记录此次数据点的维度
    //空参构造函数
    public MyKmeans(){

    }
    //包含K值的构造函数
    public MyKmeans(int k){
        this.k = k;
    }
    public MyKmeans(int k,int m){
        this.k = k;
        this.m = m;
    }

    public int getM() {
        return m;
    }

    public void setM(int m) {
        this.m = m;
    }

    public int getDataLength() {
        return dataLength;
    }

    public void setDataLength(int dataLength) {
        this.dataLength = dataLength;
    }

    public ArrayList<double[]> getData() {
        return data;
    }

    public void setData(ArrayList<double[]> data) {
        this.data = data;
    }

    public ArrayList<double[]> getCenter() {
        return center;
    }

    public void setCenter(ArrayList<double[]> center) {
        this.center = center;
    }

    public ArrayList<ArrayList<double[]>> getCluster() {
        return cluster;
    }

    public void setCluster(ArrayList<ArrayList<double[]>> cluster) {
        this.cluster = cluster;
    }
    //初始化聚类,保证程序能够正常运行
    public void init(){
//        //默认情况下迭代次数为10,参考SPSS,暂未实装
//        m=10;
        //读取数据文件
        readData();
        if (data==null || data.size()==0){
            //使用系统自带的初始数据集
            initData();
        }
        dataLength = data.size();
        //判断K的取值,如果聚类数小于0则设为1类,如果大于数据集中元素个数,则设为dataLength个类
        if (k<=0){
            k=1;
        }
        if(k>dataLength){
            k=dataLength;
        }
        //初始化聚类中心,使用随机数实现
        initCenter();
        //初始化聚类结果,此时聚类结果为K个空的簇
        initCluster();
        //初始化聚类中的误差平方和
        initSEE();
    }
    //读取数据文件
    public void readData(){
        data = new ArrayList<double[]>();
        FileInputStream fileInputStream = null;
        InputStreamReader inputStreamReader = null;
        BufferedReader bufferedReader = null;
        try {
            fileInputStream = new FileInputStream(new File("./src/data.txt"));
             inputStreamReader = new InputStreamReader(fileInputStream);
             bufferedReader = new BufferedReader(inputStreamReader);
            String str = null;
            while ((str = bufferedReader.readLine()) != null){
                //获取每一行数据,创建一个一维数组暂时存储这些数据
                int len = str.split(",").length;
                double[] temp_data = new double[len];
                for (int i = 0; i < temp_data.length; i++) {
                    temp_data[i] = Double.parseDouble(str.split(",")[i]);
                }
                data.add(temp_data);
                DIMENSION = len;
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            System.out.println("未找到指定文件!");
        } catch (IOException e) {
            e.printStackTrace();
            System.out.println("打开文件出错!");
        }finally {
            try {
                bufferedReader.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
            try {
                inputStreamReader.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
            try {
                fileInputStream.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }
    //当没有读取本地文件时,使用系统自动的初始数据集,初始数据集为10个二维平面中的点
    public void initData(){
        DIMENSION = 2;
        data = new ArrayList<double[]>();
        double[][] default_data = new double[][]{
                {0,0},
                {1,1},
                {2,2},
                {3,3},
                {4,4},
                {5,5},
                {6,6},
                {7,7},
                {8,8},
                {9,9}
        };
        for (int i = 0; i < default_data.length; i++) {
            data.add(default_data[i]);
        }
    }

    //显示聚类中心
    public void show_Cneter(ArrayList<double[]> center){
        for (int i  = 0; i < center.size(); i++) {
            System.out.print("[");
            for (int j = 0; j < DIMENSION; j++) {
                System.out.print(center.get(i)[j]);
                if (j!= DIMENSION-1)
                    System.out.print(",");
            }
            System.out.print("]\t");
        }
    }

    //初始化聚类中心,使用随机数生成
    public void initCenter(){
        center = new ArrayList<double[]>();
        for (int i = 0; i <k; i++) {
            Random random = new Random();
            double[] randoms_center;
            //随机指定数据点作为初始中心,可换用其他方法生成随机初始中心
            int index = random.nextInt(dataLength);
            randoms_center = data.get(index);
            center.add(randoms_center);
        }
        //保存初始聚类中心的副本
        //center_copy = center; 此种方法万万不可写
        center_copy = new ArrayList<>();
        center_copy.addAll(center);
        System.out.print("初始聚类中心是:");
        show_Cneter(center_copy);
        System.out.println();
        System.out.println();
    }

    //初始化聚类结果,此时聚类中包含K个孔的簇
    public void initCluster(){
        cluster = new ArrayList<>();
        for (int i = 0; i < k; i++) {
            cluster.add(new ArrayList<>());
        }
    }

    //显示每个簇中的元素
    public void show_ClusterData(ArrayList<ArrayList<double[]>> cluster){
        for (int i = 0; i < cluster.size(); i++) {
            System.out.print("类别"+(i+1)+"包含元素:");
            for (int j = 0; j < cluster.get(i).size(); j++) {
                System.out.print("[");
                for (int index = 0;index < DIMENSION; index++){
                    System.out.print(cluster.get(i).get(j)[index]);
                    if (index!= DIMENSION-1)
                        System.out.print(",");
                }
                System.out.print("]\t");
            }
            System.out.println();
        }
    }

    //初始化聚类中的误差平方和
    public void initSEE(){
        SEE = new ArrayList<>();
    }

    /**
     * 计算两个点之间的距离
     * @param p1 第一个点
     * @param p2 第二个点
     * @return 两个点见的欧式距离
     */
    private double distance(double[] p1,double[] p2){
        double result;
        double temp_sum = 0.0;
        for (int i = 0; i < p1.length; i++) {
            temp_sum += (p1[i]-p2[i])*(p1[i]-p2[i]);
        }
        result = Math.sqrt(temp_sum);
        return result;
    }

    //找到当前数据距离聚类中心最小的类别位置
    private int minDistance(double[] disstance){
        double min_distance = disstance[0];
        int min_index = 0;
        for (int i = 1; i < disstance.length; i++) {
            if (disstance[i]<=min_distance){
                min_distance = disstance[i];
                min_index = i;
            }
        }
        return min_index;
    }
    //将当前数据元素放到聚类最近的簇中
    private void clusterSet(){
        double[] dis = new double[k];
        System.out.print("此时聚类中心:");
        show_Cneter(center);
        for (int i = 0; i < data.size(); i++) {
            System.out.println();
            for (int j = 0; j < k; j++) {
                dis[j] = distance(data.get(i),center.get(j));
            }
            System.out.print("第"+i+"个元素到中心的距离是:");
            for (int j = 0; j < dis.length; j++) {
                System.out.print(dis[j] + "\t");
            }
            int location = minDistance(dis);
            cluster.get(location).add(data.get(i));
        }
        System.out.println();
        //显示此时簇中包含的元素
        show_ClusterData(cluster);
        System.out.println();
    }

    /**
     * 求两点之间的误差平方
     * @param p1 第一个点
     * @param p2 第二个点
     * @return 两点之间的误差平方(距离)
     */
    private double errorSquare(double[] p1,double[] p2){
        double temp_sum = 0.0;
        for (int i = 0; i < p1.length; i++) {
            temp_sum += (p1[i]-p2[i])*(p1[i]-p2[i]);
        }
        return temp_sum;
    }

    //计算当前分类中所有簇中误差平方和
    private void countSEE(){
        double temp = 0;
        for (int i = 0; i < cluster.size(); i++) {
            for (int j = 0; j < cluster.get(i).size(); j++) {
                //计算当前簇中的所有数据到该簇聚类中心的距离
                temp += errorSquare(cluster.get(i).get(j),center.get(i));
            }
        }
        SEE.add(temp);
    }

    //设置新的聚类中心,依照以聚类好的各簇中数据求出新的聚类中心
    private void setNewCenter(){
//        System.out.println("新的聚类中心是:");
        for (int i = 0; i < cluster.size(); i++) {
            double[] temp_center = new double[DIMENSION];
            int n = cluster.get(i).size();
            if (n != 0){
                for (int j = 0; j < n; j++) {
                    for (int index = 0; index < DIMENSION; index++) {
                        temp_center[index] += cluster.get(i).get(j)[index];
                    }
                }
                for (int j = 0; j < DIMENSION; j++) {
                    temp_center[j] = temp_center[j]/n;
                }
                //将新的聚类中心放入动态数组
                center.set(i,temp_center);
            }
//            System.out.print("["+center.get(i)[0]+","+center.get(i)[1]+"]\t");
        }
        System.out.println();
    }

    /**
     * 显示聚类最终信息
     */
    public void show(){
        System.out.print("初始聚类中心是:");
        show_Cneter(center_copy);
        System.out.println();

        System.out.print("最终聚类中心:");
        show_Cneter(center);
        System.out.println();

        System.out.println("迭代执行的次数为:"+(temp));
        System.out.print("各阶段系统误差平方和");
        for (int i = 0; i < SEE.size(); i++) {
            System.out.print(SEE.get(i)+"\t");
        }
        System.out.println();
        //显示最后系统中各簇中的元素
        show_ClusterData(cluster);
    }

    /**
     *  kmeans算法具体实施步骤
     */
    public void kmeans(){
        //第一步,初始化各参数
        init();
        //第二步,执行聚类操作,直到收敛或者到达迭代次数
        temp = 1; //用来记录迭代次数
        while (true){
            //将各数据放入对应簇中
            clusterSet();
            //计算对应的误差平方好
            countSEE();
            if (temp > m){
                break;
            }
            if (SEE.size()!=1){
                if (SEE.get(temp-1) - SEE.get(temp-2) == 0)
                    break;
            }
            //第三步,设置新的聚类中心,重新开始聚类
            setNewCenter();
            cluster.clear();
            initCluster();
            //让迭代次数增加
            temp++;
        }
    }
}

  • 34
    点赞
  • 65
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 8
    评论
在PyTorch中,要冻结某层参数,即保持其权重在训练过程中不发生更新。这种操作通常在迁移学习或固定特定的层的场景下使用,以便保留已经学到的知识。 要冻结某层参数,可以通过以下步骤实现: 1. 首先,加载模型并查看模型的结构。通过打印模型就可以看到每一层的名称以及对应的索引。 2. 在训练之前,确定需要冻结的层。可以通过模型的参数名称或索引来定位到具体的层。 3. 使用`requires_grad_()`函数来冻结参数,将需要冻结的层的`requires_grad`属性设置为False。这样,在反向传播过程中,这些参数的梯度就不会进行更新了。 4. 在训练过程中,只对其他未冻结的层进行梯度更新。 下面是一个简单的示例代码,演示如何冻结某层参数: ```python import torch import torch.nn as nn # 加载模型并创建优化器 model = torchvision.models.resnet18(pretrained=True) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 查看模型结构 print(model) # 冻结某层参数 # 可以通过模型的named_parameters()函数获取每一层的名称和参数 # 这里以冻结ResNet的第4个卷积层参数为例 for name, param in model.named_parameters(): if 'layer4' in name: # 可根据具体需求来决定冻结哪些层 param.requires_grad_(False) # 训练过程 for inputs, labels in dataloader: outputs = model(inputs) loss = loss_func(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() ``` 通过以上步骤,我们可以实现冻结某层参数的操作。这样,在训练过程中,被冻结的层的参数将不会更新,从而保持其固定的权重。
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

进击的墨菲特

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值