Skip to content

KMeansClustering

Repository source: KMeansClustering

Description

This example clusters 3D points using the k-means algorithm. The cluster assignment is recorded in an array that holds the cluster id of each point.

Other languages

See (Cxx)

Question

If you have a question about this example, please use the VTK Discourse Forum.

Code

KMeansClustering.py

#!/usr/bin/env python3

# noinspection PyUnresolvedReferences
import vtkmodules.vtkInteractionStyle
# noinspection PyUnresolvedReferences
import vtkmodules.vtkRenderingOpenGL2
from vtkmodules.vtkCommonColor import vtkNamedColors
from vtkmodules.vtkCommonCore import (
    vtkIntArray,
    vtkDoubleArray,
    vtkPoints
)
from vtkmodules.vtkCommonDataModel import (
    vtkPolyData,
    vtkTable
)
from vtkmodules.vtkFiltersGeneral import vtkVertexGlyphFilter
from vtkmodules.vtkFiltersStatistics import (
    vtkKMeansStatistics,
    vtkStatisticsAlgorithm
)
from vtkmodules.vtkInteractionStyle import vtkInteractorStyleTrackballCamera
from vtkmodules.vtkRenderingCore import (
    vtkActor,
    vtkPolyDataMapper,
    vtkRenderer,
    vtkRenderWindow,
    vtkRenderWindowInteractor
)


def main():
    """
    Cluster six 3D points into two groups with vtkKMeansStatistics and display them.

    The point coordinates are copied into a vtkTable (one column per coordinate),
    the k-means filter is run on that table, the per-point cluster ids and the
    cluster centers are printed, and finally the points are rendered with the
    cluster id attached as the point scalars.
    """
    colors = vtkNamedColors()

    # Create 2 clusters, one near (0,0,0) and the other near (3,3,3).
    points = vtkPoints()

    points.InsertNextPoint(0.0, 0.0, 0.0)
    points.InsertNextPoint(3.0, 3.0, 3.0)
    points.InsertNextPoint(0.1, 0.1, 0.1)
    points.InsertNextPoint(3.1, 3.1, 3.1)
    points.InsertNextPoint(0.2, 0.2, 0.2)
    points.InsertNextPoint(3.2, 3.2, 3.2)

    # Get the points into the format needed for KMeans:
    # a table with one single-component column per coordinate ('coord 0', 'coord 1', 'coord 2').
    input_data = vtkTable()

    for c in range(0, 3):
        col_name = f'coord {c:0d}'
        double_array = vtkDoubleArray(number_of_components=1, name=col_name, number_of_tuples=points.number_of_points)

        for r in range(0, points.number_of_points):
            p = [0.0] * 3
            points.GetPoint(r, p)

            # Copy coordinate c of point r into row r of this column.
            double_array.SetValue(r, p[c])

        input_data.AddColumn(double_array)

    k_means_statistics = vtkKMeansStatistics(input_data=(vtkStatisticsAlgorithm.INPUT_DATA, input_data),
                                           assess_option=True, default_number_of_clusters=2)
    # Set the status of all three coordinate columns to 1, then build the request from those statuses.
    k_means_statistics.SetColumnStatus(input_data.GetColumnName(0), 1)
    k_means_statistics.SetColumnStatus(input_data.GetColumnName(1), 1)
    k_means_statistics.SetColumnStatus(input_data.GetColumnName(2), 1)
    k_means_statistics.RequestSelectedColumns()
    k_means_statistics.update()

    # Display the results.
    # Fetch the output table once and reuse it instead of asking the filter for it repeatedly.
    output_table = k_means_statistics.output
    output_table.Dump()

    cluster_array = vtkIntArray(number_of_components=1, name='ClusterId')

    # The cluster id of each row is read from the last column of the output table;
    # the column index does not change, so compute it once outside the loop.
    cluster_column = output_table.GetNumberOfColumns() - 1
    for r in range(0, output_table.number_of_rows):
        cluster_id = output_table.GetValue(r, cluster_column).ToInt()
        print(f'Point {r} is in cluster {cluster_id}')
        cluster_array.InsertNextValue(cluster_id)

    # Output the cluster centers.
    # They are read from the first block of the filter's model output, using the input column names.
    output_meta_ds = k_means_statistics.GetOutputDataObject(vtkStatisticsAlgorithm.OUTPUT_MODEL)
    output_meta = output_meta_ds.GetBlock(0)
    coord0 = output_meta.GetColumnByName('coord 0')
    coord1 = output_meta.GetColumnByName('coord 1')
    coord2 = output_meta.GetColumnByName('coord 2')
    print('Cluster centers:')
    number_of_tuples = coord0.number_of_tuples
    for i in range(0, number_of_tuples):
        center = [coord0.GetValue(i), coord1.GetValue(i), coord2.GetValue(i)]
        print(f'Cluster {i}: ({fmt_floats(center)})')

    # Attach the cluster ids to the points as the active scalars.
    poly_data = vtkPolyData(points=points)
    poly_data.point_data.SetScalars(cluster_array)

    # Display
    glyph_filter = vtkVertexGlyphFilter(input_data=poly_data)

    # Create a mapper and actor.
    mapper = vtkPolyDataMapper()
    glyph_filter >> mapper

    actor = vtkActor(mapper=mapper)
    actor.property.point_size = 10
    actor.property.render_points_as_spheres = True

    # Create a renderer, render window, and interactor.
    renderer = vtkRenderer(background=colors.GetColor3d('OliveDrab'))
    render_window = vtkRenderWindow(size=(600, 600), window_name='KMeansClustering')
    render_window.AddRenderer(renderer)
    render_window_interactor = vtkRenderWindowInteractor()
    render_window_interactor.render_window = render_window

    # Add the actor to the scene
    renderer.AddActor(actor)

    style = vtkInteractorStyleTrackballCamera()
    render_window_interactor.interactor_style = style

    # Render and interact.
    render_window.Render()
    render_window_interactor.Start()


def fmt_floats(v, w=0, d=6, pt='g'):
    """
    Format a sequence of floats as one comma-separated string.

    Every element is formatted with the same format specification, built from
    the width, precision and presentation type, and the results are joined
    with ', '.

    :param v: The list or tuple of floats.
    :param w: Minimum total width of each formatted field.
    :param d: The precision: digits after the decimal point for 'f' and 'e',
     significant digits for 'g'.
    :param pt: The presentation type, 'f', 'g' or 'e' (upper case is accepted
     and lowered); any other value falls back to 'f'.
    :return: A string.
    """
    presentation = pt.lower()
    if presentation not in {'e', 'f', 'g'}:
        # Unknown presentation types fall back to fixed-point notation.
        presentation = 'f'
    # Build the format specification once; it is the same for every element.
    spec = f'{w}.{d}{presentation}'
    return ', '.join(format(value, spec) for value in v)


# Run the example only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()