diff --git a/src/semantic_code_search/cli.py b/src/semantic_code_search/cli.py
index 0b3c9a9..93ec84f 100644
--- a/src/semantic_code_search/cli.py
+++ b/src/semantic_code_search/cli.py
@@ -7,6 +7,7 @@
 
 from semantic_code_search.embed import do_embed
 from semantic_code_search.query import do_query
+from semantic_code_search.cluster import do_cluster
 
 
 def git_root(path=None):
@@ -37,6 +38,11 @@ def query_func(args):
     do_query(args, model)
 
 
+def cluster_func(args):
+    model = SentenceTransformer(args.model_name_or_path)
+    do_cluster(args, model)
+
+
 def main():
     parser = argparse.ArgumentParser(
         prog='sem', description='Search your codebase using natural language')
@@ -55,6 +61,16 @@ def main():
                         required=False, default=5, help='Number of results to return')
     parser.add_argument('-e', '--editor', choices=[
                         'vscode', 'vim'], default='vscode', required=False, help='Editor to open selected result in')
+    parser.add_argument('-c', '--cluster', action='store_true', default=False,
+                        required=False, help='Generate clusters of related functions and methods')
+    parser.add_argument('--cluster-max-distance', metavar='THRESHOLD', type=float, default=0.2, required=False,
+                        help='How close functions need to be to one another to be clustered. A distance of 0 means the code is identical; smaller values (e.g. 0.2, 0.3) are stricter and result in fewer matches')
+    parser.add_argument('--cluster-min-lines', metavar='SIZE', type=int, default=0, required=False,
+                        help='Ignore clusters with code snippets smaller than this size (lines of code). Use this if you are not interested in small duplications (e.g. one-liners)')
+    parser.add_argument('--cluster-min-cluster-size', metavar='SIZE', type=int, default=2, required=False,
+                        help='Ignore clusters smaller than this size. Use this if you want to find code that is similar and repeated many times (e.g. >5)')
+    parser.add_argument('--cluster-ignore-identical', action='store_true', default=True,
+                        required=False, help='Ignore identical code / exact duplicates (where distance is 0)')
     parser.set_defaults(func=query_func)
     parser.add_argument('query_text', nargs=argparse.REMAINDER)
 
@@ -62,6 +78,8 @@ def main():
 
     if args.embed:
         embed_func(args)
+    elif args.cluster:
+        cluster_func(args)
     else:
         query_func(args)
 
diff --git a/src/semantic_code_search/cluster.py b/src/semantic_code_search/cluster.py
new file mode 100644
index 0000000..2a731f5
--- /dev/null
+++ b/src/semantic_code_search/cluster.py
@@ -0,0 +1,83 @@
+import gzip
+import os
+import pickle
+from semantic_code_search.embed import do_embed
+from sklearn.cluster import AgglomerativeClustering
+import numpy as np
+from textwrap import indent
+
+
+def _get_clusters(dataset, distance_threshold):
+    embeddings = dataset.get('embeddings')
+    # Normalize the embeddings to unit length
+    embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
+    dataset['embeddings'] = embeddings
+
+    clustering_model = AgglomerativeClustering(
+        n_clusters=None,
+        distance_threshold=distance_threshold,
+        compute_distances=True,
+    )
+    clustering_model.fit(embeddings)
+    cluster_assignment = clustering_model.labels_
+    cluster_distances = clustering_model.distances_
+    cluster_children = clustering_model.children_
+
+    clustered_functions = {}
+    for idx, cluster_id in enumerate(cluster_assignment):
+        if cluster_id not in clustered_functions:
+            clustered_functions[cluster_id] = []
+
+        ds_entry = dataset.get('functions')[idx]
+        ds_entry['idx'] = idx
+
+        clustered_functions[cluster_id].append(ds_entry)
+
+    # filter out clusters with only one function
+    clusters = []
+    for cluster_id, functions in clustered_functions.items():
+        if len(functions) > 1:
+            fx_idx = functions[0].get('idx')
+            distances = []
+            for f in functions[1:]:
+                f_idx = f.get('idx')
+                for i, cc in enumerate(cluster_children):
+                    # children_ rows pair the two merged nodes in no guaranteed order
+                    if sorted(cc.tolist()) == sorted([fx_idx, f_idx]):
+                        distances.append(cluster_distances[i])
+            avg_distance = sum(distances) / len(distances) if distances else 0
+            clusters.append(
+                {'avg_distance': avg_distance, 'functions': functions})
+
+    return clusters
+
+
+def do_cluster(args, model):
+    if not os.path.isfile(args.path_to_repo + '/' + '.embeddings'):
+        print('Embeddings not found in {}. Generating embeddings now.'.format(
+            args.path_to_repo))
+        do_embed(args, model)
+
+    with gzip.open(args.path_to_repo + '/' + '.embeddings', 'r') as f:
+        dataset = pickle.loads(f.read())
+        if dataset.get('model_name') != args.model_name_or_path:
+            print('Model name mismatch. Regenerating embeddings.')
+            dataset = do_embed(args, model)
+        clusters = _get_clusters(dataset, args.cluster_max_distance)
+
+        filtered_clusters = []
+        for c in clusters:
+            if args.cluster_ignore_identical and c.get('avg_distance') == 0:
+                continue
+            if any(len(f.get('text').split('\n')) <= args.cluster_min_lines for f in c.get('functions')):
+                continue
+            if len(c.get('functions')) < args.cluster_min_cluster_size:
+                continue
+            filtered_clusters.append(c)
+
+        for i, c in enumerate(filtered_clusters):
+            print('Cluster #{}: avg_distance: {:.3} ================================================\n'.format(
+                i, c.get('avg_distance')))
+            for f in c.get('functions'):
+                print(indent(f.get('file'), ' ') + ':' + str(f.get('line')))
+                print(indent(f.get('text'), ' ') + '\n')
diff --git a/src/semantic_code_search/query.py b/src/semantic_code_search/query.py
index c13f0b0..d5e113f 100644
--- a/src/semantic_code_search/query.py
+++ b/src/semantic_code_search/query.py
@@ -13,7 +13,7 @@
 def _search(query_embedding, corpus_embeddings, functions, k=5, file_extension=None):
     # TODO: filtering by file extension
     cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
-    top_results = torch.topk(cos_scores, k=min(k, len(cos_scores) -1), sorted=True)
+    top_results = torch.topk(cos_scores, k=min(k, len(cos_scores)), sorted=True)
     out = []
     for score, idx in zip(top_results[0], top_results[1]):
         out.append((score, functions[idx]))
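
For reviewers unfamiliar with threshold-based agglomerative clustering, below is a minimal standalone sketch of the technique _get_clusters relies on: unit-normalize the embeddings, then let AgglomerativeClustering cut the merge tree at a fixed distance threshold rather than at a fixed cluster count. The synthetic vectors are stand-ins for sentence-transformer output; the 0.2 threshold mirrors the --cluster-max-distance default.

    import numpy as np
    from sklearn.cluster import AgglomerativeClustering

    rng = np.random.default_rng(0)
    # Five tight groups of three near-duplicate vectors, standing in for
    # embeddings of repeated code snippets.
    base = rng.normal(size=(5, 8))
    embeddings = np.vstack([base + rng.normal(scale=0.01, size=base.shape)
                            for _ in range(3)])
    embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)

    clustering = AgglomerativeClustering(
        n_clusters=None,          # cluster count falls out of the threshold
        distance_threshold=0.2,   # mirrors the --cluster-max-distance default
        compute_distances=True,
    )
    clustering.fit(embeddings)

    # labels_ gives a flat assignment; size-1 clusters are the singletons
    # that _get_clusters discards.
    for cluster_id in np.unique(clustering.labels_):
        members = np.where(clustering.labels_ == cluster_id)[0]
        if len(members) > 1:
            print(cluster_id, members.tolist())

With real embeddings, looser grouping is just a higher threshold, e.g. sem --cluster --cluster-max-distance 0.3.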
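
The .embeddings cache that do_cluster consumes is a gzipped pickle of a dict. The sketch below mirrors its access pattern under the layout implied by this patch (top-level keys 'model_name', 'functions', and 'embeddings'; per-function dicts with 'file', 'line', and 'text'). That schema is inferred from the code above rather than from separate documentation, and load_dataset is a hypothetical helper name.

    import gzip
    import pickle

    def load_dataset(repo_path):
        # Hypothetical reader mirroring do_cluster's access pattern.
        with gzip.open(repo_path + '/' + '.embeddings', 'r') as f:
            return pickle.loads(f.read())

    dataset = load_dataset('.')
    print(dataset.get('model_name'))
    for fn in dataset.get('functions')[:3]:
        print(fn.get('file'), fn.get('line'), len(fn.get('text').split('\n')))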
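
Finally, the query.py hunk fixes an off-by-one: clamping k to len(cos_scores) - 1 meant a corpus holding exactly k functions silently returned one result too few, even though torch.topk accepts k equal to the input length. A small demonstration with made-up scores:

    import torch

    cos_scores = torch.tensor([0.9, 0.5, 0.7])  # pretend 3-function corpus
    k = 3
    old_k = min(k, len(cos_scores) - 1)  # 2: drops one hit
    new_k = min(k, len(cos_scores))      # 3: returns the whole corpus
    print(torch.topk(cos_scores, k=old_k, sorted=True).values)  # tensor([0.9000, 0.7000])
    print(torch.topk(cos_scores, k=new_k, sorted=True).values)  # tensor([0.9000, 0.7000, 0.5000])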