package com.txq.kmeans;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
/**
* K-means clustering algorithm.
* @author TongXueQiang
* @date 2017/09/09
*/
public class Kmeans {
// Two-decimal formatter. Its output is parsed back with Double.valueOf elsewhere in this
// class, so the symbols are pinned to Locale.ROOT: with the platform default locale a
// comma decimal separator (e.g. "0,50") would make that parse throw NumberFormatException.
private DecimalFormat df = new DecimalFormat("#####.00", DecimalFormatSymbols.getInstance(Locale.ROOT));
// Data set currently being clustered; assigned by the public train(...) methods.
public Kmeans_data data = null;
// Sample name -> row index in the data matrix (filled by fileToMatrix).
private Map<String, Integer> identifier = new HashMap<String, Integer>();
// Row index in the data matrix -> sample name (filled by fileToMatrix).
private Map<Integer, String> iden0 = new HashMap<Integer, String>();
// Model object that receives the name/index maps, labels, counts and centers.
private ClusterModel model = new ClusterModel();
/**
 * Reads a sample file and converts it into a matrix with one row per sample.
 * <p>
 * Every non-blank line must have the form {@code features:name}, where
 * {@code features} is a whitespace-separated list of numbers and {@code name}
 * is the sample name. The number of columns is taken from the first non-blank
 * line; extra values on later lines are ignored. As a side effect the
 * name/index maps of this instance are filled and attached to the model.
 *
 * @param path path of the UTF-8 encoded sample file
 * @return the sample matrix, an empty {@code double[0][0]} when the file holds
 *         no samples, or {@code null} when the file could not be read (the
 *         I/O error is printed to standard error)
 * @throws IllegalArgumentException if a line has no {@code :name} part or has
 *         fewer feature values than the first line
 * @throws NumberFormatException if a feature value is not a valid number
 * @throws Exception declared for compatibility with existing callers
 */
public double[][] fileToMatrix(String path) throws Exception {
    List<String> contents = new ArrayList<String>();
    model.identifier = identifier;
    model.iden0 = iden0;
    FileInputStream file = null;
    BufferedReader reader = null;
    try {
        file = new FileInputStream(path);
        reader = new BufferedReader(new InputStreamReader(file, "utf-8"));
        String line;
        // Read line by line until readLine() signals end of file with null.
        while ((line = reader.readLine()) != null) {
            // Blank lines (typically a trailing one) carry no sample; skip them.
            if (line.trim().isEmpty()) {
                continue;
            }
            contents.add(line);
        }
    } catch (IOException e) {
        // Kept from the original contract: report the failure and return null.
        e.printStackTrace();
        return null;
    } finally {
        try {
            if (reader != null) {
                // Closing the reader also closes the wrapped stream.
                reader.close();
            } else if (file != null) {
                // Wrapping failed after the stream was opened; close it directly.
                file.close();
            }
        } catch (IOException ignored) {
            // Best-effort close: the data (if any) has already been read.
        }
    }
    if (contents.isEmpty()) {
        return new double[0][0];
    }
    int rows = contents.size();
    int dim = 0;
    double[][] da = null;
    for (int j = 0; j < rows; j++) {
        String line = contents.get(j);
        String[] strs = line.split(":");
        if (strs.length < 2) {
            throw new IllegalArgumentException(
                    "Malformed sample line (expected \"features:name\"): " + line);
        }
        // Tolerate leading/trailing and repeated whitespace between values.
        String[] feature = strs[0].trim().split("\\s+");
        if (j == 0) {
            // The first sample fixes the dimensionality of the whole matrix.
            dim = feature.length;
            da = new double[rows][dim];
        } else if (feature.length < dim) {
            throw new IllegalArgumentException(
                    "Expected " + dim + " feature values but found " + feature.length + ": " + line);
        }
        identifier.put(strs[1], j);
        iden0.put(j, strs[1]);
        for (int i = 0; i < dim; i++) {
            da[j][i] = Double.parseDouble(feature[i]);
        }
    }
    return da;
}
/**
 * Resets the top-left {@code highDim x lowDim} region of a matrix to zero.
 *
 * @param matrix  matrix to modify in place
 * @param highDim number of leading rows to clear
 * @param lowDim  number of leading columns to clear in each of those rows
 */
private void setDouble2Zero(double[][] matrix, int highDim, int lowDim) {
    int row = 0;
    while (row < highDim) {
        int col = 0;
        while (col < lowDim) {
            matrix[row][col] = 0.0;
            col++;
        }
        row++;
    }
}
/**
 * Copies the top-left {@code highDim x lowDim} region of one matrix into another.
 *
 * @param dests   matrix that receives the values (modified in place)
 * @param sources matrix the values are read from
 * @param highDim number of leading rows to copy
 * @param lowDim  number of leading columns to copy in each of those rows
 */
private void copyCenters(double[][] dests, double[][] sources, int highDim, int lowDim) {
    for (int center = 0; center < highDim; center++) {
        for (int axis = 0; axis < lowDim; axis++) {
            final double coordinate = sources[center][axis];
            dests[center][axis] = coordinate;
        }
    }
}
/**
 * Recomputes each cluster center as the mean of the samples currently labelled
 * with that cluster. Labels and per-cluster sample counts are read from the model.
 * <p>
 * A cluster whose count is zero keeps its previous center. Dividing its empty
 * sum by zero would produce NaN coordinates, and every later distance
 * comparison against that center would then be false.
 *
 * @param k    number of cluster centers to update
 * @param data data set whose {@code centers} array is updated in place
 */
private void updateCenters(int k, Kmeans_data data) {
    double[][] centers = data.centers;
    int[] labels = model.labels;
    int[] centerCounts = model.centerCounts;
    int dim = data.dim;
    // Accumulate into a scratch array so the previous centers stay available
    // for clusters that turn out to be empty.
    double[][] sums = new double[k][dim];
    for (int j = 0; j < data.length; j++) {
        double[] sample = data.data[j];
        double[] sum = sums[labels[j]];
        for (int i = 0; i < dim; i++) {
            sum[i] += sample[i];
        }
    }
    for (int i = 0; i < k; i++) {
        if (centerCounts[i] <= 0) {
            // Empty cluster: leave its center where it was.
            continue;
        }
        for (int j = 0; j < dim; j++) {
            centers[i][j] = sums[i][j] / centerCounts[i];
        }
    }
}
/**
 * Computes the Euclidean distance between two points.
 *
 * @param pa  first point
 * @param pb  second point
 * @param dim number of leading coordinates to include in the distance
 * @return square root of the summed squared coordinate differences
 */
public double dist(double[] pa, double[] pb, int dim) {
    double sumOfSquares = 0.0;
    for (int axis = 0; axis < dim; axis++) {
        final double delta = pa[axis] - pb[axis];
        sumOfSquares += delta * delta;
    }
    return Math.sqrt(sumOfSquares);
}
/**
 * Trains a model on the samples in a file, with the number of cluster centers
 * chosen by the caller.
 * <p>
 * The file is loaded with {@link #fileToMatrix(String)}, stored in {@link #data},
 * and training is delegated to the {@code train(int, Kmeans_param)} overload
 * (not visible in this part of the file) with default parameters.
 *
 * @param path path of the sample file
 * @param k    number of cluster centers requested by the caller
 * @return the trained cluster model
 * @throws IOException if the sample file could not be read
 * @throws IllegalArgumentException if the file contains no samples
 * @throws Exception any other failure raised while loading or training
 */
public ClusterModel train(String path, int k) throws Exception {
    double[][] matrix = fileToMatrix(path);
    // fileToMatrix reports read failures by returning null; fail with a clear
    // message instead of a NullPointerException on matrix.length.
    if (matrix == null) {
        throw new IOException("Unable to read training data from " + path);
    }
    // matrix[0] below requires at least one sample.
    if (matrix.length == 0) {
        throw new IllegalArgumentException("No samples found in " + path);
    }
    data = new Kmeans_data(matrix, matrix.length, matrix[0].length);
    return train(k, new Kmeans_param());
}
/**
 * Trains a model on the samples in a file, letting the algorithm choose the
 * number of cluster centers.
 * <p>
 * The file is loaded with {@link #fileToMatrix(String)}, stored in {@link #data},
 * and training is delegated to {@code train(Kmeans_param)} with default parameters.
 *
 * @param path path of the sample file
 * @return the trained cluster model
 * @throws IOException if the sample file could not be read
 * @throws IllegalArgumentException if the file contains no samples
 * @throws Exception any other failure raised while loading or training
 */
public ClusterModel train(String path) throws Exception {
    double[][] matrix = fileToMatrix(path);
    // fileToMatrix reports read failures by returning null; fail with a clear
    // message instead of a NullPointerException on matrix.length.
    if (matrix == null) {
        throw new IOException("Unable to read training data from " + path);
    }
    // matrix[0] below requires at least one sample.
    if (matrix.length == 0) {
        throw new IllegalArgumentException("No samples found in " + path);
    }
    data = new Kmeans_data(matrix, matrix.length, matrix[0].length);
    return train(new Kmeans_param());
}
private ClusterModel train(Kmeans_param param) {
int k = Kmeans_param.K;
// 首先進行資料歸一化處理
normalize(data);
// 計算第一個樣本和后面的所有樣本的歐氏距離,存入list中然后計算均值,作為聚類中心選取的依據
List<Double> dists = new ArrayList<Double>();
for (int i = 1; i < data.length; i++) {
dists.add(dist(data.data[0], data.data[i], data.dim));
}
param.min_euclideanDistance = Double.valueOf(df.format((Collections.max(dists) + Collections.min(dists)) / 2));
double euclideanDistance = param.min_euclideanDistance > 0 ? param.min_euclideanDistance
: Kmeans_param.MIN_EuclideanDistance;
int centerIndexes[] = new int[k];// 收集聚類中心索引的陣列
int countCenter = 0;// 動態表示中心的數目
int count = 0;// 計數器
centerIndexes[0] = 0;
countCenter++;
for (int i = 1; i < data.length; i++) {
for (int j = 0; j < countCenter; j++) {
if (dist(data.data[i], data.data[centerIndexes[j]], data.dim) > euclideanDistance) {
count++;
}
}
if (count == countCenter) {
centerIndexes[countCenter++] = i;
}
count = 0;
}
double[][] centers = new double[countCenter][data.dim]; // 聚類中心
data.centers = centers;
int[] centerCounts = new int[countCenter]; // 聚類中心的樣本個數
model.centerCounts = centerCounts;
Arrays.fill(centerCounts, 0);
int[] labels = new int[data.length]; // 樣本的類別
model.labels = labels;
double[][] oldCenters = new double[countCenter][data.dim]; // 存盤舊的聚類中心
// 給聚類中心賦值
for (int i = 0; i < countCenter; i++) {
int m = centerIndexes[i];
for (int j = 0; j < data.dim; j++) {
centers[i][j] = data.data[m][j];
}
}
// 給最初始的聚類中心賦值
model.originalCenters = new double[countCenter][data.dim];
for (int i = 0; i < countCenter; i++) {
for (int j = 0; j < data.dim; j++) {
model.originalCenters[i][j] = centers[i][j];
}
}
//初始聚類
for (int i = 0; i < data.length; i++) {
double minDist = dist(data.data[i], centers[0], data.dim);
int label = 0;
for (int j = 1; j < countCenter; j++) {
double tempDist = dist(data.data[i], centers[j], data.dim);
if (tempDist < minDist) {
minDist = tempDist;
label = j;
}
}
labels[i] = label;
centerCounts[label]++;
}
updateCenters(countCenter, data);
copyCenters(oldCenters, centers, countCenter, data.dim);
// 迭代預處理
int maxAttempts = param.attempts > 0 ? param.attempts : Kmeans_param.MAX_ATTEMPTS;
int attempts = 1;
double criteria = param.criteria > 0 ? param.criteria : Kmeans_param.MIN_CRITERIA;
double criteriaBreakCondition = 0;
boolean[] flags = new boolean[k]; // 用來表示聚類中心是否發生變化
// 迭代
iterate: while (attempts < maxAttempts) { // 迭代次數不超過最大值,最大中心改變數不超過閾值
for (int i = 0; i < countCenter; i++) { // 初始化中心點"是否被修改過"標記
flags[i] = false;
}
for (int i = 0; i < data.length; i++) {
double minDist = dist(data.data[i], centers[0], data.dim);
int label = 0;
for (int j = 1; j < countCenter; j++) {
double tempDist = dist(data.data[i], centers[j], data.dim);
if (tempDist < minDist) {
minDist = tempDist;
label = j;
}
}
if (label != labels[i]) { // 如果當前點被聚類到新的類別則做更新
int oldLabel = labels[i];
labels[i] = label;
centerCounts[oldLabel]--;
centerCounts[label]++;
flags[oldLabel] = true;
flags[label] = true;
}
}
updateCenters(countCenter, data);
attempts++;
// 計算被修改過的中心點最大修改量是否超過閾值
double maxDist = 0;
for (int i = 0; i < countCenter; i++) {
if (flags[i]) {
double tempDist = dist(centers[i], oldCenters[i], data.dim);
if (maxDist < tempDist) {
maxDist = tempDist;
}
for (int j = 0; j < data.dim; j++) { // 更新oldCenter
oldCenters[i][j] = centers[i][j];
oldCenters[i][j] = Double.valueOf(df.format(oldCenters[i][j]));
}
}
}
if (maxDist < criteria) {
criteriaBreakCondition = maxDist;
break iterate;
}
}
// 把結果存盤到ClusterModel中
ClusterModel rvInfo = outputClusterInfo(criteriaBreakCondition, countCenter,
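// NOTE: the listing is truncated here; the text below is the footer of the web page it was copied from.
// Please credit the source when reposting. Original article: https://www.uj5u.com/houduan/84861.html
// Tags: Java
// Previous article: Spring Boot dependency configuration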
