xref: /freebsd/contrib/arm-optimized-routines/math/tools/plot.py (revision 072a4ba82a01476eaee33781ccd241033eefcf0b)
131914882SAlex Richardson#!/usr/bin/python
231914882SAlex Richardson
331914882SAlex Richardson# ULP error plot tool.
431914882SAlex Richardson#
531914882SAlex Richardson# Copyright (c) 2019, Arm Limited.
6*072a4ba8SAndrew Turner# SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception
731914882SAlex Richardson
831914882SAlex Richardsonimport numpy as np
931914882SAlex Richardsonimport matplotlib.pyplot as plt
1031914882SAlex Richardsonimport sys
1131914882SAlex Richardsonimport re
1231914882SAlex Richardson
1331914882SAlex Richardson# example usage:
1431914882SAlex Richardson# build/bin/ulp -e .0001 log 0.5 2.0 2345678 | math/tools/plot.py
1531914882SAlex Richardson
def fhex(s):
	"""Parse a hexadecimal floating-point string (e.g. '0x1.8p+1') into a float."""
	parsed = float.fromhex(s)
	return parsed
1831914882SAlex Richardson
def parse(f):
	"""Collect per-sample data from the text output of the ulp tool.

	Every line of *f* matching ``NAME(X) got G want Y <field> ulp err E``
	contributes one sample: X, G and Y are parsed as hex floats via fhex,
	E as a decimal float.  Lines starting with 'PASS' or 'FAIL' are echoed
	to stdout (one line each); any other line is ignored.

	Returns a tuple of four lists (xs, gs, ys, es) in input order:
	arguments, computed ('got') values, reference ('want') values and
	reported ulp errors.
	"""
	xs = []
	gs = []
	ys = []
	es = []
	# Has to match the format used in ulp.c
	r = re.compile(r'[^ (]+\(([^ )]*)\) got ([^ ]+) want ([^ ]+) [^ ]+ ulp err ([^ ]+)')
	for line in f:
		m = r.match(line)
		if m:
			xs.append(fhex(m.group(1)))
			gs.append(fhex(m.group(2)))
			ys.append(fhex(m.group(3)))
			es.append(float(m.group(4)))
		elif line.startswith(('PASS', 'FAIL')):
			# Echo the summary line.  The line still carries its own
			# newline, so strip it to avoid print() emitting a blank line.
			print(line.rstrip('\n'))
	return xs, gs, ys, es
4131914882SAlex Richardson
def plot(xs, gs, ys, es):
	"""Show ulp errors and got/want values against the input argument.

	The top panel plots the absolute ulp error of each sample and annotates
	the largest one (as a hex float and with %g).  The bottom panel overlays
	the reference values ('want', red) and computed values ('got', blue).
	With fewer than two samples, prints a message and returns without
	plotting.
	"""
	if len(xs) <= 1:
		print('not enough samples')
		return
	lo, hi = min(xs), max(xs)
	_, (err_ax, val_ax) = plt.subplots(nrows=2)
	abs_err = np.abs(es)  # the sign of the error is not shown
	peak = max(abs_err)
	label = '%s\n%g' % (peak.hex(), peak)
	# Place the annotation near the top-right of the error panel.
	err_ax.text(lo + 0.7 * (hi - lo), 0.8 * peak, label)
	err_ax.plot(xs, abs_err, 'r.')
	err_ax.grid()
	for values, fmt, name in ((ys, 'r.', 'want'), (gs, 'b.', 'got')):
		val_ax.plot(xs, values, fmt, label=name)
	val_ax.grid()
	val_ax.legend()
	plt.show()
5931914882SAlex Richardson
# Script entry: read ulp output from stdin, then plot the parsed samples.
data = parse(sys.stdin)
xs, gs, ys, es = data
plot(*data)
62