次のコードがあります。最適な属性を選択するための属性に数値がない場合、正しく機能します。ただし、年齢属性などの属性に数値がある場合、コードを修正して機能させる方法がわかりません。
from arff_v2 import *
import math
def entropy(data, target_attr):
    """
    Calculate the Shannon entropy of `data` for the attribute at
    column index `target_attr`.

    Params:
        data: list of records (each record is an indexable row)
        target_attr: column index of the attribute to measure
    Returns:
        entropy in bits (0.0 for an empty data set or a single value)
    """
    # Frequency of each distinct value of the target attribute.
    # dict.get replaces the Python-2-only dict.has_key() method.
    value_freq = {}
    for record in data:
        value = record[target_attr]
        value_freq[value] = value_freq.get(value, 0.0) + 1.0
    # Sum -p * log2(p) over the observed value probabilities.
    total = float(len(data))
    result = 0.0
    for count in value_freq.values():
        p = count / total
        result -= p * math.log(p, 2)
    return result
# Load the training set; `data` is a list of rows (lists of strings) and
# `attrs` is a list of {'name': ..., 'vals': ...} dicts from the arff_v2 parser.
relation, comments, attrs, data = readArff('heart_train.arff')
#print attrs #for self-testing
#print len(attrs) #for self-testing
# Entropy of the class attribute (column 13) over the whole data set.
ep=entropy(data,13)
#neg_pos={} #for self-testing
#for record in data: #for self-testing
# print record[13] #for self-testing
# print neg_pos.has_key(record[13]) #for self-testing
#print ep
def mutual_information(data, attr, target_attr):
    """
    Information gain (reduction in entropy of `target_attr`) obtained by
    splitting `data` on the attribute at column index `attr`.

    NOTE(review): every distinct value of `attr` becomes its own branch,
    so a numeric attribute such as 'age' is treated as categorical here.
    Handling continuous attributes properly requires choosing a threshold
    split (e.g. age < 50) rather than branching per value.

    Params:
        data: list of records (indexable rows)
        attr: column index of the candidate split attribute
        target_attr: column index of the class attribute
    Returns:
        the information gain as a float
    """
    # Frequency of each value of the split attribute
    # (dict.get replaces the Python-2-only dict.has_key()).
    value_freq = {}
    for record in data:
        value_freq[record[attr]] = value_freq.get(record[attr], 0.0) + 1.0
    # Weighted sum of the entropies of each value subset
    # (the stray debug print of the frequency table was removed).
    total = sum(value_freq.values())
    subset_entropy = 0.0
    for value in value_freq:
        weight = value_freq[value] / total
        subset = [record for record in data if record[attr] == value]
        subset_entropy += weight * entropy(subset, target_attr)
    # Renamed the result variable: the original shadowed the function name.
    return entropy(data, target_attr) - subset_entropy
# Information gain of attribute 5 ('fbs') with respect to the class (column 13).
mi=mutual_information(data,5,13)
#print "difference"
#print data[:]
def attr_select(data, attrs, target_attr):
    """
    Return the column index of the attribute with the highest
    information gain, never selecting the target attribute itself.

    Params:
        data: list of records (indexable rows)
        attrs: list of attribute descriptors (only its length is used)
        target_attr: column index of the class attribute
    Returns:
        column index of the best attribute, or None when data is empty
    """
    best_gain = 0.0
    best_attr = None
    # range() replaces the original's hand-maintained counter; the gain is
    # now computed only for non-target columns (the original computed it
    # for the target column too and then discarded it).
    for index in range(len(attrs)):
        if index == target_attr:
            continue  # never split on the class attribute
        gain = mutual_information(data, index, target_attr)
        # >= keeps the original tie-breaking: later columns win ties.
        if gain >= best_gain:
            best_gain = gain
            best_attr = index
    return best_attr
# Index of the best attribute to split on first; printed for inspection.
ch=attr_select(data,attrs,13)
print ch
def retrieve_examples(data, attr, value):
    """
    Return the records in `data` whose attribute at column index `attr`
    equals `value`, preserving their original order.

    Fixes two bugs in the original:
      * the recursive calls referred to an undefined name `get_examples`
        (NameError as soon as the function ran on non-empty data);
      * `data.pop()` destructively emptied the caller's list.
    """
    return [record for record in data if record[attr] == value]
# Call the filter under its defined name (the original called the undefined
# `get_examples`) and avoid shadowing the builtin `list`.
male_examples = retrieve_examples(data, 1, 'male')
#print male_examples
これが私が使用しているデータです。.arff 形式です。
@relation cleveland-14-heart-disease
@attribute 'age' real
@attribute 'sex' { female, male}
@attribute 'cp' { typ_angina, asympt, non_anginal, atyp_angina}
@attribute 'trestbps' real
@attribute 'chol' real
@attribute 'fbs' { t, f}
@attribute 'restecg' { left_vent_hyper, normal, st_t_wave_abnormality}
@attribute 'thalach' real
@attribute 'exang' { no, yes}
@attribute 'oldpeak' real
@attribute 'slope' { up, flat, down}
@attribute 'ca' real
@attribute 'thal' { fixed_defect, normal, reversable_defect}
@attribute 'class' { negative, positive}
@data
63,male,typ_angina,145,233,t,left_vent_hyper,150,no,2.3,down,0,fixed_defect, positive
37,male,non_anginal,130,250,f,normal,187,no,3.5,down,0,normal,negative
41,female,atyp_angina,130,204,f,left_vent_hyper,172,no,1.4,up,0,normal,negative
56,male,atyp_angina,120,236,f,normal,178,no,0.8,up,0,normal,negative
57,female,asympt,120,354,f,normal,163,yes,0.6,up,0,normal,positive
57,male,asympt,140,192,f,normal,148,no,0.4,flat,0,fixed_defect,negative
56,female,atyp_angina,140,294,f,left_vent_hyper,153,no,1.3,flat,0,normal,negative
44,male,atyp_angina,120,263,f,normal,173,no,0,up,0,reversable_defect,negative
52,male,non_anginal,172,199,t,normal,162,no,0.5,up,0,reversable_defect,negative
57,male,non_anginal,150,168,f,normal,174,no,1.6,up,0,normal,negative
54,male,asympt,140,239,f,normal,160,no,1.2,up,0,normal,negative
48,female,non_anginal,130,275,f,normal,139,no,0.2,up,0,normal,positive
また、これは私がウェブで見つけたarffパーサーで、正常に動作します。私は同じディレクトリに入れました:
from __future__ import division
"""
Operations on WEKA .arff files
Created on 28/09/2010
@author: peter
"""
import sys, re, os, datetime
def getAttributeByName_(attributes, name):
    """ Return the first attribute dict whose 'name' field equals <name>,
        or None when no attribute matches. """
    return next((attr for attr in attributes if attr['name'] == name), None)
def showAttributeByName_(attributes, name, title):
print '>>>', title, ':', getAttributeByName(attributes, name)
def debugAttributes(attributes, title):
    """ Debugging hook for attribute lists; intentionally a no-op
        (the diagnostic call below is disabled). """
    # showAttributeByName(attributes, 'Number.of.Successful.Grant', title)
    return None
def writeArff2(filename, comments, relation, attr_keys, attrs, data, make_copies = False):
""" Write a WEKA .arff file
Params:
filename: name of .arff file
comments: free text comments
relation: name of data set
attr_keys: gives order of keys in attrs to match columns
attrs: dict of attribute: all values of attribute
data: the actual data
"""
assert(len(attr_keys) == len(attrs))
assert(len(data[0]) == len(attrs))
assert(len(attrs) >= 2)
f = file(filename, 'w')
f.write('\n')
f.write('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n')
f.write('%% %s \n' % os.path.basename(filename))
f.write('%\n')
f.write('% Created by ' + os.path.basename(sys.argv[0]) + ' on ' + datetime.date.today().strftime("%A, %d %B %Y") + '\n')
f.write('% Code at http://bit.ly/read_arff\n')
f.write('%\n')
f.write('%% %d instances\n' % len(data))
f.write('%% %d attributes + 1 class = %d columns\n' % (len(data[0]) - 1, len(data[0])))
f.write('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n')
f.write('\n')
if comments:
f.write('% Original comments\n')
for c in comments:
f.write(c + '\n')
f.write('@RELATION ' + relation + '\n\n')
for name in attr_keys:
vals = attrs[name]
if type(vals) is str:
attrs_str = vals
else:
attrs_str = '{%s}' % ','.join([x for x in vals if not x == '?'])
f.write('@ATTRIBUTE %-15s %s\n' % (name, attrs_str))
f.write('\n@DATA\n\n')
for instance in data:
instance = ['?' if x == '' else x for x in instance]
for i,name in enumerate(attr_keys):
if type(attrs) is list:
assert(instance[i] in attrs[name]+ ['?'])
f.write(', '.join(instance) + '\n')
#print ', '.join(instance)
f.close()
#print attr_keys[0], attrs[attr_keys[0]]
#exit()
if make_copies:
""" Copy .arff files to .arff.txt so they can be viewed from Google docs """
print 'writeArff:', filename + '.txt', '-- duplicate'
shutil.copyfile(filename, filename + '.txt')
def writeArff(filename, comments, relation, attrs_list, data, make_copies = False, test = True):
    """ Write a WEKA .arff file
        Params:
            filename: name of .arff file
            comments: free text comments
            relation: name of data set
            attrs_list: list of dicts of attribute: all values of attribute
            data: the actual data
            make_copies: also write a '.txt' duplicate of the file
            test: re-read the written file and assert it round-trips exactly
    """
    assert(len(attrs_list) > 0)
    assert(len(data) > 0)
    debugAttributes(attrs_list, 'writeArff')
    # Preserve the column order while converting the attribute list to the
    # key-list + dict form that writeArff2 expects.
    attr_keys = [x['name'] for x in attrs_list]
    attrs_dict = {}
    for x in attrs_list:
        attrs_dict[x['name']] = x['vals']
    writeArff2(filename, comments, relation, attr_keys, attrs_dict, data, make_copies)
    if test:
        # Round-trip check: reading the written file back must reproduce the
        # inputs; mismatched attributes are dumped side by side first.
        out_relation, out_comments, out_attrs_list, out_data = readArff(filename)
        if out_attrs_list != attrs_list:
            print 'len(out_attrs_list) =', len(out_attrs_list), ', len(attrs_list) =', len(attrs_list)
            if len(out_attrs_list) == len(attrs_list):
                for i in range(len(attrs_list)):
                    print '%3d:' % i, out_attrs_list[i], attrs_list[i]
        assert(out_relation == relation)
        assert(out_attrs_list == attrs_list)
        assert(out_data == data)
def getRe(pattern, text):
    """ Return the list of all matches of <pattern> found in <text>. """
    found = re.findall(pattern, text)
    return found
# Pre-compiled patterns for the .arff header lines.
relation_pattern = re.compile(r'@RELATION\s*(\S+)\s*$', re.IGNORECASE)
# First token after @ATTRIBUTE: the attribute name.
attr_name_pattern = re.compile(r'@ATTRIBUTE\s*(\S+)\s*', re.IGNORECASE)
# Second token after @ATTRIBUTE: the type (e.g. real).
attr_type_pattern = re.compile(r'@ATTRIBUTE\s*\S+\s*(\S+)', re.IGNORECASE)
# Everything between the braces of a nominal value list.
attr_vals_pattern = re.compile(r'\{\s*(.+)\s*\}', re.IGNORECASE)
# One CSV field, honouring double-quoted fields that may contain commas.
csv_pattern = re.compile(r'(?:^|,)(\"(?:[^\"]+|\"\")*\"|[^,]*)', re.IGNORECASE)
def readArff(filename):
    """ Read a WEKA .arff file
        Params:
            filename: name of .arff file
        Returns:
            (relation, comments, attrs, data) where
            relation: name of data set
            comments: list of '%' comment lines
            attrs: list of {'name': ..., 'vals': ...} dicts; 'vals' is a
                type string for typed attributes or a list for nominal ones
            data: list of rows, each a list of stripped field strings
    """
    print 'readArff(%s)' % filename
    lines = file(filename).readlines()
    # Strip line endings / surrounding whitespace and drop blank lines.
    lines = [l.rstrip('\n').strip() for l in lines]
    lines = [l for l in lines if len(l)]
    # '%' lines are comments; keep them separately and parse the rest.
    comments = [l for l in lines if l[0] == '%']
    lines = [l for l in lines if not l[0] == '%']
    relation = [l for l in lines if '@RELATION' in l.upper()]
    attributes = [l for l in lines if '@ATTRIBUTE' in l.upper()]
    #for i,a in enumerate(attributes[8:12]):
    #    print '%4d' % (8+i), a
    # Every line after the @DATA marker is an instance row.
    data = []
    in_data = False
    for l in lines:
        if in_data:
            data.append(l)
        elif '@DATA' in l.upper():
            in_data = True
    #print 'relation =', relation
    out_relation = getRe(relation_pattern, relation[0])[0]
    out_attrs = []
    for l in attributes:
        name = getRe(attr_name_pattern, l)[0]
        if not '{' in l:
            # Typed attribute (e.g. real): keep the type string as 'vals'.
            vals_string = getRe(attr_type_pattern, l)[0].strip()
            vals = vals_string.strip()
        else:
            # Nominal attribute: 'vals' becomes the list of allowed values.
            vals_string = getRe(attr_vals_pattern, l)[0]
            vals = [x.strip() for x in vals_string.split(',')]
        out_attrs.append({'name':name, 'vals':vals})
        if False:
            # Disabled debug output for a specific troublesome attribute.
            print name, vals
            if name == 'Number.of.Successful.Grant':
                exit()
    #print 'out_attrs:', out_attrs
    out_data = []
    for l in data:
        out_data.append([x.strip() for x in getRe(csv_pattern, l)])
    # Every row must have exactly one field per attribute.
    for d in out_data:
        assert(len(out_attrs) == len(d))
    debugAttributes(out_attrs, 'readArff')
    return (out_relation, comments, out_attrs, out_data)
def testCsv():
    """ Round-trip test: read the .arff file named on the command line and
        write a '.copy' version (writeArff re-reads it to verify). """
    if len(sys.argv) != 2:
        print "Usage: arff.py <arff-file>"
        sys.exit()
    in_file_name = sys.argv[1]
    # Insert '.copy' before the extension, e.g. heart.arff -> heart.copy.arff
    out_file_name = os.path.splitext(in_file_name)[0] + '.copy' + os.path.splitext(in_file_name)[1]
    print 'Reading', in_file_name
    print 'Writing', out_file_name
    relation, comments, attrs, data = readArff(in_file_name)
    writeArff(out_file_name, comments, relation, attrs, data)
if __name__ == '__main__':
    if True:
        # Sanity-check the CSV field regex on a quoted field containing a
        # comma; the escaped and raw spellings of the pattern must agree.
        line = '1,a,"x,y",q'
        pattern = '(?:^|,)(\\\"(?:[^\\\"]+|\\\"\\\")*\\\"|[^,]*)'
        patter2 = r'(?:^|,)(\"(?:[^\"]+|\"\")*\"|[^,]*)'
        print pattern
        print patter2
        assert(patter2 == pattern)
        vals = re.findall(pattern, line)
        print pattern
        print line
        print vals
    if True:
        testCsv()