为了更好地理解聚类算法,我从网上找了现成的代码来学习,发现了一个用Java编写的机器学习库Java-ML,链接:http://java-ml.sourceforge.net/

有兴趣的读者可以下载其源码,借此理解基础的机器学习算法。对于DBSCAN算法,我从网上找到了一个Java实现,主要用于理解其算法流程。参考代码如下:

1、Point类,数据对象

package sk.cluster;

public class Point {

private double x;//坐标x轴

private double y;//坐标y轴

private boolean isVisit;//是佛访问标记

private int cluster;//所属簇类

private boolean isNoised;//是否是噪音数据

public Point(double x,double y) {

this.x = x;

this.y = y;

this.isVisit = false;

this.cluster = 0;

this.isNoised = false;

}

public double getDistance(Point point) {//计算两点间距离

return Math.sqrt((x-point.x)*(x-point.x)+(y-point.y)*(y-point.y));

}

public void setX(double x) {

this.x = x;

}

public double getX() {

return x;

}

public void setY(double y) {

this.y = y;

}

public double getY() {

return y;

}

public void setVisit(boolean isVisit) {

this.isVisit = isVisit;

}

public boolean getVisit() {

return isVisit;

}

public int getCluster() {

return cluster;

}

public void setNoised(boolean isNoised) {

this.isNoised = isNoised;

}

public void setCluster(int cluster) {

this.cluster = cluster;

}

public boolean getNoised() {

return this.isNoised;

}

@Override

public String toString() {

return x+" "+y+" "+cluster+" "+(isNoised?1:0);

}

}

2、Data类,数据集

package sk.cluster;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Locale;
import java.util.Random;

public class Data {

private static DecimalFormat df=(DecimalFormat) NumberFormat.getInstance();

//随机生成数据

public static ArrayList generateSinData(int size) {

ArrayList points = new ArrayList(size);

Random rd = new Random(size);

for (int i=0;i

double x = format(Math.PI / (size / 2) * (i + 1));

double y = format(Math.sin(x)) ;

points.add(new Point(x,y));

}

for (int i=0;i

double x = format(1.5 + Math.PI / (size/2) * (i+1));

double y = format(Math.cos(x));

points.add(new Point(x,y));

}

return points;

}

//输入指定数据

public static ArrayList generateSpecialData() {

ArrayList points = new ArrayList();

points.add(new Point(2,2));

points.add(new Point(3,1));

points.add(new Point(3,4));

points.add(new Point(3,14));

points.add(new Point(5,3));

points.add(new Point(8,3));

points.add(new Point(8,6));

points.add(new Point(9,8));

points.add(new Point(10,4));

points.add(new Point(10,7));

points.add(new Point(10,10));

points.add(new Point(10,14));

points.add(new Point(11,13));

points.add(new Point(12,7));

points.add(new Point(12,15));

points.add(new Point(14,7));

points.add(new Point(14,9));

points.add(new Point(14,15));

points.add(new Point(15,8));

return points;

}

//获取文件数据

public static ArrayList getData(String sourcePath) {

ArrayList points = new ArrayList();

File fileIn = new File(sourcePath);

try {

BufferedReader br = new BufferedReader(new FileReader(fileIn));

String line = null;

line = br.readLine();

while (line != null) {

Double x = Double.parseDouble(line.split(",")[3]);

Double y = Double.parseDouble(line.split(",")[4]);

points.add(new Point(x, y));

line = br.readLine();

}

br.close();

} catch (IOException e) {

e.printStackTrace();

}

return points;

}

//输出到文件

public static void writeData(ArrayList points,String path) {

try {

BufferedWriter bw = new BufferedWriter(new FileWriter(path));

for (Point point:points) {

bw.write(point.toString()+"\n");

}

bw.close();

} catch (IOException e) {

e.printStackTrace();

}

}

private static double format(double x) {

return Double.valueOf(df.format(x));

}

}

3、DBScan类,实现DBSCAN算法

package sk.cluster;

import java.util.ArrayList;

public class DBScan {

private double radius;

private int minPts;

public DBScan(double radius,int minPts) {

this.radius = radius;//领域半径参数

this.minPts = minPts;//领域密度值,该领域内有多少个样本

}

public void process(ArrayList points) {

int size = points.size();

int idx = 0;

int cluster = 1;

while (idx

Point p = points.get(idx++);

//choose an unvisited point

if (!p.getVisit()) {

p.setVisit(true);//set visited

ArrayList adjacentPoints = getAdjacentPoints(p, points);//计算两点距离,看是否在领域内

//set the point which adjacent points less than minPts noised

if (adjacentPoints != null && adjacentPoints.size() < minPts) {

p.setNoised(true);//噪音数据

} else {//建立该点作为领域核心对象

p.setCluster(cluster);

for (int i = 0; i < adjacentPoints.size(); i++) {

Point adjacentPoint = adjacentPoints.get(i);//领域内的样本

//only check unvisited point, cause only unvisited have the chance to add new adjacent points

if (!adjacentPoint.getVisit()) {

adjacentPoint.setVisit(true);

ArrayList adjacentAdjacentPoints = getAdjacentPoints(adjacentPoint, points);

//add point which adjacent points not less than minPts noised

if (adjacentAdjacentPoints != null && adjacentAdjacentPoints.size() >= minPts) {

//adjacentPoints.addAll(adjacentAdjacentPoints);

for (Point pp : adjacentAdjacentPoints){

if (!adjacentPoints.contains(pp)){

adjacentPoints.add(pp);

}

}

}

}

//add point which doest not belong to any cluster

if (adjacentPoint.getCluster() == 0) {

adjacentPoint.setCluster(cluster);

//set point which marked noised before non-noised

if (adjacentPoint.getNoised()) {

adjacentPoint.setNoised(false);

}

}

}

cluster++;

}

}

if (idx%1000==0) {

System.out.println(idx);

}

}

}

private ArrayList getAdjacentPoints(Point centerPoint,ArrayList points) {

ArrayList adjacentPoints = new ArrayList();

for (Point p:points) {

//include centerPoint itself

double distance = centerPoint.getDistance(p);

if (distance<=radius) {

adjacentPoints.add(p);

}

}

return adjacentPoints;

}

}

/*

##DBScan算法伪代码

算法:DBScan,基于密度的聚类算法

输入:

D:一个包含n个数据的数据集

r:半径参数

minPts:邻域密度阈值

输出:基于密度的聚类集合

标记D中所有的点为unvisited

for each p in D

If p.visit == unvisited

标记p为visited

找出与点p距离不大于r的所有点集合N

If N.size() < minPts

标记点p为噪声点

Else

创建一个新簇,将p聚到该簇(当前簇)

for each p' in N

If p'.visit == unvisited

标记p'为visited

找出与点p'距离不大于r的所有点集合N'

If N'.size()>=minPts

将N'中尚未在N中的点加入集合N

End if

End If

If p'未被聚到某个簇

将p'聚到当前簇

If p'被标记为噪声点

将p'取消标记为噪声点

End If

End If

End for

End if

End if

End for

*/

4、Client测试类

package sk.cluster;

import java.util.ArrayList;

public class Client {

    /** Output file used when no path is given on the command line. */
    private static final String DEFAULT_OUTPUT_PATH = "D:\\tmp\\data.txt";

    /**
     * Demo: clusters a generated data set with DBSCAN, prints every labelled point and writes the
     * same lines to a file.
     *
     * @param args optional; {@code args[0]} overrides the output file path, which otherwise
     *     defaults to DEFAULT_OUTPUT_PATH (a Windows path)
     */
    public static void main(String[] args) {
        // Deterministic, not random: 100 points on y = sin(x) plus 100 on y = cos(x) shifted by 1.5.
        ArrayList<Point> points = Data.generateSinData(200);
        // radius = neighbourhood radius; minPts = minimum neighbourhood size for a core point.
        DBScan dbScan = new DBScan(0.6, 4);

        // Alternative inputs:
        //   ArrayList<Point> points = Data.generateSpecialData();            // fixed hand-picked points
        //   ArrayList<Point> points = Data.getData("D:\\tmp\\testData.txt"); // points read from a file
        //   DBScan dbScan = new DBScan(0.1, 1000);  // presumably the parameters for the file data

        dbScan.process(points);

        for (Point p : points) {
            System.out.println(p);
        }

        String outputPath = (args != null && args.length > 0) ? args[0] : DEFAULT_OUTPUT_PATH;
        Data.writeData(points, outputPath);
    }
}

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐